mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-01 15:34:29 -07:00
Pull contentGenerator out of GeminiClient and into Config. (#7825)
This commit is contained in:
committed by
GitHub
parent
d33defde68
commit
6e26c88c2c
@@ -40,7 +40,7 @@ export async function createCodeAssistContentGenerator(
|
||||
export function getCodeAssistServer(
|
||||
config: Config,
|
||||
): CodeAssistServer | undefined {
|
||||
let server = config.getGeminiClient().getContentGenerator();
|
||||
let server = config.getContentGenerator();
|
||||
|
||||
// Unwrap LoggingContentGenerator if present
|
||||
if (server instanceof LoggingContentGenerator) {
|
||||
|
||||
@@ -71,18 +71,12 @@ vi.mock('../tools/memoryTool', () => ({
|
||||
GEMINI_CONFIG_DIR: '.gemini',
|
||||
}));
|
||||
|
||||
vi.mock('../core/contentGenerator.js', async (importOriginal) => {
|
||||
const actual =
|
||||
await importOriginal<typeof import('../core/contentGenerator.js')>();
|
||||
return {
|
||||
...actual,
|
||||
createContentGeneratorConfig: vi.fn(),
|
||||
};
|
||||
});
|
||||
vi.mock('../core/contentGenerator.js');
|
||||
|
||||
vi.mock('../core/client.js', () => ({
|
||||
GeminiClient: vi.fn().mockImplementation(() => ({
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
stripThoughtsFromHistory: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
@@ -196,7 +190,9 @@ describe('Server Config (config.ts)', () => {
|
||||
apiKey: 'test-key',
|
||||
};
|
||||
|
||||
(createContentGeneratorConfig as Mock).mockReturnValue(mockContentConfig);
|
||||
vi.mocked(createContentGeneratorConfig).mockReturnValue(
|
||||
mockContentConfig,
|
||||
);
|
||||
|
||||
// Set fallback mode to true to ensure it gets reset
|
||||
config.setFallbackMode(true);
|
||||
@@ -217,172 +213,38 @@ describe('Server Config (config.ts)', () => {
|
||||
expect(config.isInFallbackMode()).toBe(false);
|
||||
});
|
||||
|
||||
it('should preserve conversation history when refreshing auth', async () => {
|
||||
const config = new Config(baseParams);
|
||||
const authType = AuthType.USE_GEMINI;
|
||||
const mockContentConfig = {
|
||||
model: 'gemini-pro',
|
||||
apiKey: 'test-key',
|
||||
};
|
||||
|
||||
(createContentGeneratorConfig as Mock).mockReturnValue(mockContentConfig);
|
||||
|
||||
// Mock the existing client with some history
|
||||
const mockExistingHistory = [
|
||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||
{ role: 'model', parts: [{ text: 'Hi there!' }] },
|
||||
{ role: 'user', parts: [{ text: 'How are you?' }] },
|
||||
];
|
||||
|
||||
const mockExistingClient = {
|
||||
isInitialized: vi.fn().mockReturnValue(true),
|
||||
getHistory: vi.fn().mockReturnValue(mockExistingHistory),
|
||||
};
|
||||
|
||||
const mockNewClient = {
|
||||
isInitialized: vi.fn().mockReturnValue(true),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
setHistory: vi.fn(),
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
// Set the existing client
|
||||
(
|
||||
config as unknown as { geminiClient: typeof mockExistingClient }
|
||||
).geminiClient = mockExistingClient;
|
||||
(GeminiClient as Mock).mockImplementation(() => mockNewClient);
|
||||
|
||||
await config.refreshAuth(authType);
|
||||
|
||||
// Verify that existing history was retrieved
|
||||
expect(mockExistingClient.getHistory).toHaveBeenCalled();
|
||||
|
||||
// Verify that new client was created and initialized
|
||||
expect(GeminiClient).toHaveBeenCalledWith(config);
|
||||
expect(mockNewClient.initialize).toHaveBeenCalledWith(mockContentConfig);
|
||||
|
||||
// Verify that history was restored to the new client
|
||||
expect(mockNewClient.setHistory).toHaveBeenCalledWith(
|
||||
mockExistingHistory,
|
||||
{ stripThoughts: false },
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle case when no existing client is initialized', async () => {
|
||||
const config = new Config(baseParams);
|
||||
const authType = AuthType.USE_GEMINI;
|
||||
const mockContentConfig = {
|
||||
model: 'gemini-pro',
|
||||
apiKey: 'test-key',
|
||||
};
|
||||
|
||||
(createContentGeneratorConfig as Mock).mockReturnValue(mockContentConfig);
|
||||
|
||||
const mockNewClient = {
|
||||
isInitialized: vi.fn().mockReturnValue(true),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
setHistory: vi.fn(),
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
// No existing client
|
||||
(config as unknown as { geminiClient: null }).geminiClient = null;
|
||||
(GeminiClient as Mock).mockImplementation(() => mockNewClient);
|
||||
|
||||
await config.refreshAuth(authType);
|
||||
|
||||
// Verify that new client was created and initialized
|
||||
expect(GeminiClient).toHaveBeenCalledWith(config);
|
||||
expect(mockNewClient.initialize).toHaveBeenCalledWith(mockContentConfig);
|
||||
|
||||
// Verify that setHistory was not called since there was no existing history
|
||||
expect(mockNewClient.setHistory).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should strip thoughts when switching from GenAI to Vertex', async () => {
|
||||
const config = new Config(baseParams);
|
||||
const mockContentConfig = {
|
||||
model: 'gemini-pro',
|
||||
apiKey: 'test-key',
|
||||
authType: AuthType.USE_GEMINI,
|
||||
};
|
||||
(
|
||||
config as unknown as { contentGeneratorConfig: ContentGeneratorConfig }
|
||||
).contentGeneratorConfig = mockContentConfig;
|
||||
|
||||
(createContentGeneratorConfig as Mock).mockReturnValue({
|
||||
...mockContentConfig,
|
||||
authType: AuthType.LOGIN_WITH_GOOGLE,
|
||||
});
|
||||
vi.mocked(createContentGeneratorConfig).mockImplementation(
|
||||
(_: Config, authType: AuthType | undefined) =>
|
||||
({ authType }) as unknown as ContentGeneratorConfig,
|
||||
);
|
||||
|
||||
const mockExistingHistory = [
|
||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||
];
|
||||
const mockExistingClient = {
|
||||
isInitialized: vi.fn().mockReturnValue(true),
|
||||
getHistory: vi.fn().mockReturnValue(mockExistingHistory),
|
||||
};
|
||||
const mockNewClient = {
|
||||
isInitialized: vi.fn().mockReturnValue(true),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
setHistory: vi.fn(),
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
(
|
||||
config as unknown as { geminiClient: typeof mockExistingClient }
|
||||
).geminiClient = mockExistingClient;
|
||||
(GeminiClient as Mock).mockImplementation(() => mockNewClient);
|
||||
await config.refreshAuth(AuthType.USE_GEMINI);
|
||||
|
||||
await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE);
|
||||
|
||||
expect(mockNewClient.setHistory).toHaveBeenCalledWith(
|
||||
mockExistingHistory,
|
||||
{ stripThoughts: true },
|
||||
);
|
||||
expect(
|
||||
config.getGeminiClient().stripThoughtsFromHistory,
|
||||
).toHaveBeenCalledWith();
|
||||
});
|
||||
|
||||
it('should not strip thoughts when switching from Vertex to GenAI', async () => {
|
||||
const config = new Config(baseParams);
|
||||
const mockContentConfig = {
|
||||
model: 'gemini-pro',
|
||||
apiKey: 'test-key',
|
||||
authType: AuthType.LOGIN_WITH_GOOGLE,
|
||||
};
|
||||
(
|
||||
config as unknown as { contentGeneratorConfig: ContentGeneratorConfig }
|
||||
).contentGeneratorConfig = mockContentConfig;
|
||||
|
||||
(createContentGeneratorConfig as Mock).mockReturnValue({
|
||||
...mockContentConfig,
|
||||
authType: AuthType.USE_GEMINI,
|
||||
});
|
||||
vi.mocked(createContentGeneratorConfig).mockImplementation(
|
||||
(_: Config, authType: AuthType | undefined) =>
|
||||
({ authType }) as unknown as ContentGeneratorConfig,
|
||||
);
|
||||
|
||||
const mockExistingHistory = [
|
||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||
];
|
||||
const mockExistingClient = {
|
||||
isInitialized: vi.fn().mockReturnValue(true),
|
||||
getHistory: vi.fn().mockReturnValue(mockExistingHistory),
|
||||
};
|
||||
const mockNewClient = {
|
||||
isInitialized: vi.fn().mockReturnValue(true),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
setHistory: vi.fn(),
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
(
|
||||
config as unknown as { geminiClient: typeof mockExistingClient }
|
||||
).geminiClient = mockExistingClient;
|
||||
(GeminiClient as Mock).mockImplementation(() => mockNewClient);
|
||||
await config.refreshAuth(AuthType.USE_VERTEX_AI);
|
||||
|
||||
await config.refreshAuth(AuthType.USE_GEMINI);
|
||||
|
||||
expect(mockNewClient.setHistory).toHaveBeenCalledWith(
|
||||
mockExistingHistory,
|
||||
{ stripThoughts: false },
|
||||
);
|
||||
expect(
|
||||
config.getGeminiClient().stripThoughtsFromHistory,
|
||||
).not.toHaveBeenCalledWith();
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -6,9 +6,13 @@
|
||||
|
||||
import * as path from 'node:path';
|
||||
import process from 'node:process';
|
||||
import type { ContentGeneratorConfig } from '../core/contentGenerator.js';
|
||||
import type {
|
||||
ContentGenerator,
|
||||
ContentGeneratorConfig,
|
||||
} from '../core/contentGenerator.js';
|
||||
import {
|
||||
AuthType,
|
||||
createContentGenerator,
|
||||
createContentGeneratorConfig,
|
||||
} from '../core/contentGenerator.js';
|
||||
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
@@ -44,7 +48,6 @@ import { shouldAttemptBrowserLaunch } from '../utils/browser.js';
|
||||
import type { MCPOAuthConfig } from '../mcp/oauth-provider.js';
|
||||
import { IdeClient } from '../ide/ide-client.js';
|
||||
import { ideContext } from '../ide/ideContext.js';
|
||||
import type { Content } from '@google/genai';
|
||||
import type { FileSystemService } from '../services/fileSystemService.js';
|
||||
import { StandardFileSystemService } from '../services/fileSystemService.js';
|
||||
import { logCliConfiguration, logIdeConnection } from '../telemetry/loggers.js';
|
||||
@@ -57,6 +60,8 @@ import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
import { Storage } from './storage.js';
|
||||
import { FileExclusions } from '../utils/ignorePatterns.js';
|
||||
import type { EventEmitter } from 'node:events';
|
||||
import type { UserTierId } from '../code_assist/types.js';
|
||||
import { ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
|
||||
export enum ApprovalMode {
|
||||
DEFAULT = 'default',
|
||||
@@ -226,6 +231,7 @@ export class Config {
|
||||
private readonly sessionId: string;
|
||||
private fileSystemService: FileSystemService;
|
||||
private contentGeneratorConfig!: ContentGeneratorConfig;
|
||||
private contentGenerator!: ContentGenerator;
|
||||
private readonly embeddingModel: string;
|
||||
private readonly sandbox: SandboxConfig | undefined;
|
||||
private readonly targetDir: string;
|
||||
@@ -386,6 +392,11 @@ export class Config {
|
||||
if (this.telemetrySettings.enabled) {
|
||||
initializeTelemetry(this);
|
||||
}
|
||||
|
||||
if (this.getProxy()) {
|
||||
setGlobalDispatcher(new ProxyAgent(this.getProxy() as string));
|
||||
}
|
||||
this.geminiClient = new GeminiClient(this);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -410,46 +421,45 @@ export class Config {
|
||||
this.promptRegistry = new PromptRegistry();
|
||||
this.toolRegistry = await this.createToolRegistry();
|
||||
logCliConfiguration(this, new StartSessionEvent(this, this.toolRegistry));
|
||||
|
||||
await this.geminiClient.initialize();
|
||||
}
|
||||
|
||||
getContentGenerator(): ContentGenerator {
|
||||
return this.contentGenerator;
|
||||
}
|
||||
|
||||
async refreshAuth(authMethod: AuthType) {
|
||||
// Save the current conversation history before creating a new client
|
||||
let existingHistory: Content[] = [];
|
||||
if (this.geminiClient && this.geminiClient.isInitialized()) {
|
||||
existingHistory = this.geminiClient.getHistory();
|
||||
// Vertex and Genai have incompatible encryption and sending history with
|
||||
// throughtSignature from Genai to Vertex will fail, we need to strip them
|
||||
if (
|
||||
this.contentGeneratorConfig?.authType === AuthType.USE_GEMINI &&
|
||||
authMethod === AuthType.LOGIN_WITH_GOOGLE
|
||||
) {
|
||||
// Restore the conversation history to the new client
|
||||
this.geminiClient.stripThoughtsFromHistory();
|
||||
}
|
||||
|
||||
// Create new content generator config
|
||||
const newContentGeneratorConfig = createContentGeneratorConfig(
|
||||
this,
|
||||
authMethod,
|
||||
);
|
||||
|
||||
// Create and initialize new client in local variable first
|
||||
const newGeminiClient = new GeminiClient(this);
|
||||
await newGeminiClient.initialize(newContentGeneratorConfig);
|
||||
|
||||
// Vertex and Genai have incompatible encryption and sending history with
|
||||
// throughtSignature from Genai to Vertex will fail, we need to strip them
|
||||
const fromGenaiToVertex =
|
||||
this.contentGeneratorConfig?.authType === AuthType.USE_GEMINI &&
|
||||
authMethod === AuthType.LOGIN_WITH_GOOGLE;
|
||||
|
||||
this.contentGenerator = await createContentGenerator(
|
||||
newContentGeneratorConfig,
|
||||
this,
|
||||
this.getSessionId(),
|
||||
);
|
||||
// Only assign to instance properties after successful initialization
|
||||
this.contentGeneratorConfig = newContentGeneratorConfig;
|
||||
this.geminiClient = newGeminiClient;
|
||||
|
||||
// Restore the conversation history to the new client
|
||||
if (existingHistory.length > 0) {
|
||||
this.geminiClient.setHistory(existingHistory, {
|
||||
stripThoughts: fromGenaiToVertex,
|
||||
});
|
||||
}
|
||||
|
||||
// Reset the session flag since we're explicitly changing auth and using default model
|
||||
this.inFallbackMode = false;
|
||||
}
|
||||
|
||||
getUserTier(): UserTierId | undefined {
|
||||
return this.contentGenerator?.userTier;
|
||||
}
|
||||
|
||||
getSessionId(): string {
|
||||
return this.sessionId;
|
||||
}
|
||||
|
||||
@@ -4,24 +4,9 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mocked,
|
||||
} from 'vitest';
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
|
||||
import type {
|
||||
Chat,
|
||||
Content,
|
||||
EmbedContentResponse,
|
||||
GenerateContentResponse,
|
||||
Part,
|
||||
} from '@google/genai';
|
||||
import { GoogleGenAI } from '@google/genai';
|
||||
import type { Content, GenerateContentResponse, Part } from '@google/genai';
|
||||
import {
|
||||
findIndexAfterFraction,
|
||||
isThinkingDefault,
|
||||
@@ -34,7 +19,7 @@ import {
|
||||
type ContentGeneratorConfig,
|
||||
} from './contentGenerator.js';
|
||||
import { type GeminiChat } from './geminiChat.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import {
|
||||
CompressionStatus,
|
||||
GeminiEventType,
|
||||
@@ -76,12 +61,8 @@ vi.mock('node:fs', () => {
|
||||
});
|
||||
|
||||
// --- Mocks ---
|
||||
const mockChatCreateFn = vi.fn();
|
||||
const mockGenerateContentFn = vi.fn();
|
||||
const mockEmbedContentFn = vi.fn();
|
||||
const mockTurnRunFn = vi.fn();
|
||||
|
||||
vi.mock('@google/genai');
|
||||
vi.mock('./turn', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('./turn.js')>();
|
||||
// Define a mock class that has the same shape as the real Turn
|
||||
@@ -228,6 +209,8 @@ describe('isThinkingDefault', () => {
|
||||
});
|
||||
|
||||
describe('Gemini Client (client.ts)', () => {
|
||||
let mockContentGenerator: ContentGenerator;
|
||||
let mockConfig: Config;
|
||||
let client: GeminiClient;
|
||||
beforeEach(async () => {
|
||||
vi.resetAllMocks();
|
||||
@@ -235,29 +218,15 @@ describe('Gemini Client (client.ts)', () => {
|
||||
// Disable 429 simulation for tests
|
||||
setSimulate429(false);
|
||||
|
||||
// Set up the mock for GoogleGenAI constructor and its methods
|
||||
const MockedGoogleGenAI = vi.mocked(GoogleGenAI);
|
||||
MockedGoogleGenAI.mockImplementation(() => {
|
||||
const mock = {
|
||||
chats: { create: mockChatCreateFn },
|
||||
models: {
|
||||
generateContent: mockGenerateContentFn,
|
||||
embedContent: mockEmbedContentFn,
|
||||
},
|
||||
};
|
||||
return mock as unknown as GoogleGenAI;
|
||||
});
|
||||
|
||||
mockChatCreateFn.mockResolvedValue({} as Chat);
|
||||
mockGenerateContentFn.mockResolvedValue({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: '{"key": "value"}' }],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse);
|
||||
mockContentGenerator = {
|
||||
generateContent: vi.fn().mockResolvedValue({
|
||||
candidates: [{ content: { parts: [{ text: '{"key": "value"}' }] } }],
|
||||
}),
|
||||
generateContentStream: vi.fn(),
|
||||
countTokens: vi.fn(),
|
||||
embedContent: vi.fn(),
|
||||
batchEmbedContents: vi.fn(),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
// Because the GeminiClient constructor kicks off an async process (startChat)
|
||||
// that depends on a fully-formed Config object, we need to mock the
|
||||
@@ -273,7 +242,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
vertexai: false,
|
||||
authType: AuthType.USE_GEMINI,
|
||||
};
|
||||
const mockConfigObject = {
|
||||
mockConfig = {
|
||||
getContentGeneratorConfig: vi
|
||||
.fn()
|
||||
.mockReturnValue(contentGeneratorConfig),
|
||||
@@ -309,46 +278,18 @@ describe('Gemini Client (client.ts)', () => {
|
||||
storage: {
|
||||
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
|
||||
},
|
||||
};
|
||||
const MockedConfig = vi.mocked(Config, true);
|
||||
MockedConfig.mockImplementation(
|
||||
() => mockConfigObject as unknown as Config,
|
||||
);
|
||||
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
|
||||
} as unknown as Config;
|
||||
|
||||
// We can instantiate the client here since Config is mocked
|
||||
// and the constructor will use the mocked GoogleGenAI
|
||||
client = new GeminiClient(
|
||||
new Config({ sessionId: 'test-session-id' } as never),
|
||||
);
|
||||
mockConfigObject.getGeminiClient.mockReturnValue(client);
|
||||
|
||||
await client.initialize(contentGeneratorConfig);
|
||||
client = new GeminiClient(mockConfig);
|
||||
await client.initialize();
|
||||
vi.mocked(mockConfig.getGeminiClient).mockReturnValue(client);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
// NOTE: The following tests for startChat were removed due to persistent issues with
|
||||
// the @google/genai mock. Specifically, the mockChatCreateFn (representing instance.chats.create)
|
||||
// was not being detected as called by the GeminiClient instance.
|
||||
// This likely points to a subtle issue in how the GoogleGenerativeAI class constructor
|
||||
// and its instance methods are mocked and then used by the class under test.
|
||||
// For future debugging, ensure that the `this.client` in `GeminiClient` (which is an
|
||||
// instance of the mocked GoogleGenerativeAI) correctly has its `chats.create` method
|
||||
// pointing to `mockChatCreateFn`.
|
||||
// it('startChat should call getCoreSystemPrompt with userMemory and pass to chats.create', async () => { ... });
|
||||
// it('startChat should call getCoreSystemPrompt with empty string if userMemory is empty', async () => { ... });
|
||||
|
||||
// NOTE: The following tests for generateJson were removed due to persistent issues with
|
||||
// the @google/genai mock, similar to the startChat tests. The mockGenerateContentFn
|
||||
// (representing instance.models.generateContent) was not being detected as called, or the mock
|
||||
// was not preventing an actual API call (leading to API key errors).
|
||||
// For future debugging, ensure `this.client.models.generateContent` in `GeminiClient` correctly
|
||||
// uses the `mockGenerateContentFn`.
|
||||
// it('generateJson should call getCoreSystemPrompt with userMemory and pass to generateContent', async () => { ... });
|
||||
// it('generateJson should call getCoreSystemPrompt with empty string if userMemory is empty', async () => { ... });
|
||||
|
||||
describe('generateEmbedding', () => {
|
||||
const texts = ['hello world', 'goodbye world'];
|
||||
const testEmbeddingModel = 'test-embedding-model';
|
||||
@@ -358,18 +299,17 @@ describe('Gemini Client (client.ts)', () => {
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
];
|
||||
const mockResponse: EmbedContentResponse = {
|
||||
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
|
||||
embeddings: [
|
||||
{ values: mockEmbeddings[0] },
|
||||
{ values: mockEmbeddings[1] },
|
||||
],
|
||||
};
|
||||
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||
});
|
||||
|
||||
const result = await client.generateEmbedding(texts);
|
||||
|
||||
expect(mockEmbedContentFn).toHaveBeenCalledTimes(1);
|
||||
expect(mockEmbedContentFn).toHaveBeenCalledWith({
|
||||
expect(mockContentGenerator.embedContent).toHaveBeenCalledTimes(1);
|
||||
expect(mockContentGenerator.embedContent).toHaveBeenCalledWith({
|
||||
model: testEmbeddingModel,
|
||||
contents: texts,
|
||||
});
|
||||
@@ -379,11 +319,11 @@ describe('Gemini Client (client.ts)', () => {
|
||||
it('should return an empty array if an empty array is passed', async () => {
|
||||
const result = await client.generateEmbedding([]);
|
||||
expect(result).toEqual([]);
|
||||
expect(mockEmbedContentFn).not.toHaveBeenCalled();
|
||||
expect(mockContentGenerator.embedContent).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should throw an error if API response has no embeddings array', async () => {
|
||||
mockEmbedContentFn.mockResolvedValue({} as EmbedContentResponse); // No `embeddings` key
|
||||
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({});
|
||||
|
||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||
'No embeddings found in API response.',
|
||||
@@ -391,20 +331,19 @@ describe('Gemini Client (client.ts)', () => {
|
||||
});
|
||||
|
||||
it('should throw an error if API response has an empty embeddings array', async () => {
|
||||
const mockResponse: EmbedContentResponse = {
|
||||
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
|
||||
embeddings: [],
|
||||
};
|
||||
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||
});
|
||||
|
||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||
'No embeddings found in API response.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if API returns a mismatched number of embeddings', async () => {
|
||||
const mockResponse: EmbedContentResponse = {
|
||||
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
|
||||
embeddings: [{ values: [1, 2, 3] }], // Only one for two texts
|
||||
};
|
||||
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||
});
|
||||
|
||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||
'API returned a mismatched number of embeddings. Expected 2, got 1.',
|
||||
@@ -412,10 +351,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
});
|
||||
|
||||
it('should throw an error if any embedding has nullish values', async () => {
|
||||
const mockResponse: EmbedContentResponse = {
|
||||
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
|
||||
embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad
|
||||
};
|
||||
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||
});
|
||||
|
||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||
'API returned an empty embedding for input text at index 1: "goodbye world"',
|
||||
@@ -423,10 +361,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
});
|
||||
|
||||
it('should throw an error if any embedding has an empty values array', async () => {
|
||||
const mockResponse: EmbedContentResponse = {
|
||||
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
|
||||
embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad
|
||||
};
|
||||
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||
});
|
||||
|
||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||
'API returned an empty embedding for input text at index 0: "hello world"',
|
||||
@@ -434,8 +371,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
});
|
||||
|
||||
it('should propagate errors from the API call', async () => {
|
||||
const apiError = new Error('API Failure');
|
||||
mockEmbedContentFn.mockRejectedValue(apiError);
|
||||
vi.mocked(mockContentGenerator.embedContent).mockRejectedValue(
|
||||
new Error('API Failure'),
|
||||
);
|
||||
|
||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||
'API Failure',
|
||||
@@ -449,12 +387,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
const schema = { type: 'string' };
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
// Mock countTokens
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
||||
totalTokens: 1,
|
||||
});
|
||||
|
||||
await client.generateJson(
|
||||
contents,
|
||||
@@ -463,7 +398,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
|
||||
expect(mockGenerateContentFn).toHaveBeenCalledWith(
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
{
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
config: {
|
||||
@@ -489,11 +424,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
const customModel = 'custom-json-model';
|
||||
const customConfig = { temperature: 0.9, topK: 20 };
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
||||
totalTokens: 1,
|
||||
});
|
||||
|
||||
await client.generateJson(
|
||||
contents,
|
||||
@@ -503,7 +436,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
customConfig,
|
||||
);
|
||||
|
||||
expect(mockGenerateContentFn).toHaveBeenCalledWith(
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
{
|
||||
model: customModel,
|
||||
config: {
|
||||
@@ -524,10 +457,10 @@ describe('Gemini Client (client.ts)', () => {
|
||||
|
||||
describe('addHistory', () => {
|
||||
it('should call chat.addHistory with the provided content', async () => {
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
const mockChat = {
|
||||
addHistory: vi.fn(),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
} as unknown as GeminiChat;
|
||||
client['chat'] = mockChat;
|
||||
|
||||
const newContent = {
|
||||
role: 'user',
|
||||
@@ -568,7 +501,6 @@ describe('Gemini Client (client.ts)', () => {
|
||||
});
|
||||
|
||||
describe('tryCompressChat', () => {
|
||||
const mockCountTokens = vi.fn();
|
||||
const mockSendMessage = vi.fn();
|
||||
const mockGetHistory = vi.fn();
|
||||
|
||||
@@ -577,10 +509,6 @@ describe('Gemini Client (client.ts)', () => {
|
||||
tokenLimit: vi.fn(),
|
||||
}));
|
||||
|
||||
client['contentGenerator'] = {
|
||||
countTokens: mockCountTokens,
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
client['chat'] = {
|
||||
getHistory: mockGetHistory,
|
||||
addHistory: vi.fn(),
|
||||
@@ -600,27 +528,21 @@ describe('Gemini Client (client.ts)', () => {
|
||||
setHistory: vi.fn(),
|
||||
sendMessage: vi.fn().mockResolvedValue({ text: 'Summary' }),
|
||||
};
|
||||
const mockCountTokens = vi
|
||||
.fn()
|
||||
vi.mocked(mockContentGenerator.countTokens)
|
||||
.mockResolvedValueOnce({ totalTokens: 1000 })
|
||||
.mockResolvedValueOnce({ totalTokens: 5000 });
|
||||
|
||||
const mockGenerator: Partial<Mocked<ContentGenerator>> = {
|
||||
countTokens: mockCountTokens,
|
||||
};
|
||||
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
client['startChat'] = vi.fn().mockResolvedValue({ ...mockChat });
|
||||
|
||||
return { client, mockChat, mockGenerator };
|
||||
return { client, mockChat };
|
||||
}
|
||||
|
||||
describe('when compression inflates the token count', () => {
|
||||
it('uses the truncated history for compression');
|
||||
it('allows compression to be forced/manual after a failure', async () => {
|
||||
const { client, mockGenerator } = setup();
|
||||
mockGenerator.countTokens?.mockResolvedValue({
|
||||
const { client } = setup();
|
||||
|
||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
||||
totalTokens: 1000,
|
||||
});
|
||||
await client.tryCompressChat('prompt-id-4'); // Fails
|
||||
@@ -635,6 +557,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
|
||||
it('yields the result even if the compression inflated the tokens', async () => {
|
||||
const { client } = setup();
|
||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
||||
totalTokens: 1000,
|
||||
});
|
||||
const result = await client.tryCompressChat('prompt-id-4', true);
|
||||
|
||||
expect(result).toEqual({
|
||||
@@ -654,7 +579,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
|
||||
it('restores the history back to the original', async () => {
|
||||
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||
mockCountTokens.mockResolvedValue({
|
||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
||||
totalTokens: 999,
|
||||
});
|
||||
|
||||
@@ -679,13 +604,13 @@ describe('Gemini Client (client.ts)', () => {
|
||||
});
|
||||
|
||||
it('will not attempt to compress context after a failure', async () => {
|
||||
const { client, mockGenerator } = setup();
|
||||
const { client } = setup();
|
||||
await client.tryCompressChat('prompt-id-4');
|
||||
|
||||
const result = await client.tryCompressChat('prompt-id-5');
|
||||
|
||||
// it counts tokens for {original, compressed} and then never again
|
||||
expect(mockGenerator.countTokens).toHaveBeenCalledTimes(2);
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenCalledTimes(2);
|
||||
expect(result).toEqual({
|
||||
compressionStatus: CompressionStatus.NOOP,
|
||||
newTokenCount: 0,
|
||||
@@ -696,7 +621,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
|
||||
it('attempts to compress with a maxOutputTokens set to the original token count', async () => {
|
||||
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||
mockCountTokens.mockResolvedValue({
|
||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
||||
totalTokens: 999,
|
||||
});
|
||||
|
||||
@@ -728,8 +653,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
mockGetHistory.mockReturnValue([
|
||||
{ role: 'user', parts: [{ text: '...history...' }] },
|
||||
]);
|
||||
|
||||
mockCountTokens.mockResolvedValue({
|
||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
||||
totalTokens: MOCKED_TOKEN_LIMIT * 0.699, // TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7
|
||||
});
|
||||
|
||||
@@ -763,7 +687,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD;
|
||||
const newTokenCount = 100;
|
||||
|
||||
mockCountTokens
|
||||
vi.mocked(mockContentGenerator.countTokens)
|
||||
.mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check
|
||||
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
|
||||
|
||||
@@ -800,7 +724,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD;
|
||||
const newTokenCount = 100;
|
||||
|
||||
mockCountTokens
|
||||
vi.mocked(mockContentGenerator.countTokens)
|
||||
.mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check
|
||||
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
|
||||
|
||||
@@ -853,7 +777,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
const originalTokenCount = 1000 * 0.7;
|
||||
const newTokenCount = 100;
|
||||
|
||||
mockCountTokens
|
||||
vi.mocked(mockContentGenerator.countTokens)
|
||||
.mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check
|
||||
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
|
||||
|
||||
@@ -895,7 +819,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
const originalTokenCount = 10; // Well below threshold
|
||||
const newTokenCount = 5;
|
||||
|
||||
mockCountTokens
|
||||
vi.mocked(mockContentGenerator.countTokens)
|
||||
.mockResolvedValueOnce({ totalTokens: originalTokenCount })
|
||||
.mockResolvedValueOnce({ totalTokens: newTokenCount });
|
||||
|
||||
@@ -922,10 +846,16 @@ describe('Gemini Client (client.ts)', () => {
|
||||
});
|
||||
|
||||
it('should use current model from config for token counting after sendMessage', async () => {
|
||||
const initialModel = client['config'].getModel();
|
||||
const initialModel = mockConfig.getModel();
|
||||
|
||||
const mockCountTokens = vi
|
||||
.fn()
|
||||
// mock the model has been changed between calls of `countTokens`
|
||||
const firstCurrentModel = initialModel + '-changed-1';
|
||||
const secondCurrentModel = initialModel + '-changed-2';
|
||||
vi.mocked(mockConfig.getModel)
|
||||
.mockReturnValueOnce(firstCurrentModel)
|
||||
.mockReturnValueOnce(secondCurrentModel);
|
||||
|
||||
vi.mocked(mockContentGenerator.countTokens)
|
||||
.mockResolvedValueOnce({ totalTokens: 100000 })
|
||||
.mockResolvedValueOnce({ totalTokens: 5000 });
|
||||
|
||||
@@ -936,35 +866,23 @@ describe('Gemini Client (client.ts)', () => {
|
||||
{ role: 'model', parts: [{ text: 'Long response' }] },
|
||||
];
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
const mockChat = {
|
||||
getHistory: vi.fn().mockReturnValue(mockChatHistory),
|
||||
setHistory: vi.fn(),
|
||||
sendMessage: mockSendMessage,
|
||||
};
|
||||
} as unknown as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: mockCountTokens,
|
||||
};
|
||||
|
||||
// mock the model has been changed between calls of `countTokens`
|
||||
const firstCurrentModel = initialModel + '-changed-1';
|
||||
const secondCurrentModel = initialModel + '-changed-2';
|
||||
vi.spyOn(client['config'], 'getModel')
|
||||
.mockReturnValueOnce(firstCurrentModel)
|
||||
.mockReturnValueOnce(secondCurrentModel);
|
||||
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
client['chat'] = mockChat;
|
||||
client['startChat'] = vi.fn().mockResolvedValue(mockChat);
|
||||
|
||||
const result = await client.tryCompressChat('prompt-id-4', true);
|
||||
|
||||
expect(mockCountTokens).toHaveBeenCalledTimes(2);
|
||||
expect(mockCountTokens).toHaveBeenNthCalledWith(1, {
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenCalledTimes(2);
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenNthCalledWith(1, {
|
||||
model: firstCurrentModel,
|
||||
contents: mockChatHistory,
|
||||
});
|
||||
expect(mockCountTokens).toHaveBeenNthCalledWith(2, {
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenNthCalledWith(2, {
|
||||
model: secondCurrentModel,
|
||||
contents: expect.any(Array),
|
||||
});
|
||||
@@ -980,22 +898,11 @@ describe('Gemini Client (client.ts)', () => {
|
||||
describe('sendMessageStream', () => {
|
||||
it('emits a compression event when the context was automatically compressed', async () => {
|
||||
// Arrange
|
||||
const mockStream = (async function* () {
|
||||
yield { type: 'content', value: 'Hello' };
|
||||
})();
|
||||
mockTurnRunFn.mockReturnValue(mockStream);
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
mockTurnRunFn.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { type: 'content', value: 'Hello' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const compressionInfo: ChatCompressionInfo = {
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
@@ -1042,18 +949,6 @@ describe('Gemini Client (client.ts)', () => {
|
||||
})();
|
||||
mockTurnRunFn.mockReturnValue(mockStream);
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
const compressionInfo: ChatCompressionInfo = {
|
||||
compressionStatus,
|
||||
originalTokenCount: 1000,
|
||||
@@ -1105,24 +1000,19 @@ describe('Gemini Client (client.ts)', () => {
|
||||
},
|
||||
});
|
||||
|
||||
vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true);
|
||||
vi.mocked(mockConfig.getIdeMode).mockReturnValue(true);
|
||||
|
||||
const mockStream = (async function* () {
|
||||
yield { type: 'content', value: 'Hello' };
|
||||
})();
|
||||
mockTurnRunFn.mockReturnValue(mockStream);
|
||||
mockTurnRunFn.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { type: 'content', value: 'Hello' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
const mockChat = {
|
||||
addHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
} as unknown as GeminiChat;
|
||||
client['chat'] = mockChat;
|
||||
|
||||
const initialRequest: Part[] = [{ text: 'Hi' }];
|
||||
|
||||
@@ -1186,12 +1076,6 @@ ${JSON.stringify(
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
const initialRequest = [{ text: 'Hi' }];
|
||||
|
||||
// Act
|
||||
@@ -1241,12 +1125,6 @@ ${JSON.stringify(
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
const initialRequest = [{ text: 'Hi' }];
|
||||
|
||||
// Act
|
||||
@@ -1317,12 +1195,6 @@ ${JSON.stringify(
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
const initialRequest = [{ text: 'Hi' }];
|
||||
|
||||
// Act
|
||||
@@ -1369,12 +1241,6 @@ ${JSON.stringify(
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
// Act
|
||||
const stream = client.sendMessageStream(
|
||||
[{ text: 'Hi' }],
|
||||
@@ -1419,12 +1285,6 @@ ${JSON.stringify(
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
// Use a signal that never gets aborted
|
||||
const abortController = new AbortController();
|
||||
const signal = abortController.signal;
|
||||
@@ -1512,12 +1372,6 @@ ${JSON.stringify(
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
// Act & Assert
|
||||
// Run up to the limit
|
||||
for (let i = 0; i < MAX_SESSION_TURNS; i++) {
|
||||
@@ -1574,12 +1428,6 @@ ${JSON.stringify(
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
// Use a signal that never gets aborted
|
||||
const abortController = new AbortController();
|
||||
const signal = abortController.signal;
|
||||
@@ -1659,12 +1507,6 @@ ${JSON.stringify(
|
||||
]),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
});
|
||||
|
||||
const testCases = [
|
||||
@@ -1921,11 +1763,6 @@ ${JSON.stringify(
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true);
|
||||
vi.mocked(ideContext.getIdeContext).mockReturnValue({
|
||||
workspaceState: {
|
||||
@@ -2266,12 +2103,6 @@ ${JSON.stringify(
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
// Act
|
||||
const stream = client.sendMessageStream(
|
||||
[{ text: 'Hi' }],
|
||||
@@ -2308,12 +2139,6 @@ ${JSON.stringify(
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
// Act
|
||||
const stream = client.sendMessageStream(
|
||||
[{ text: 'Hi' }],
|
||||
@@ -2335,13 +2160,6 @@ ${JSON.stringify(
|
||||
const generationConfig = { temperature: 0.5 };
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
// Mock countTokens
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
await client.generateContent(
|
||||
contents,
|
||||
generationConfig,
|
||||
@@ -2349,7 +2167,7 @@ ${JSON.stringify(
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
|
||||
expect(mockGenerateContentFn).toHaveBeenCalledWith(
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
{
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
config: {
|
||||
@@ -2371,12 +2189,6 @@ ${JSON.stringify(
|
||||
|
||||
vi.spyOn(client['config'], 'getModel').mockReturnValueOnce(currentModel);
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
await client.generateContent(
|
||||
contents,
|
||||
{},
|
||||
@@ -2384,12 +2196,12 @@ ${JSON.stringify(
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
|
||||
expect(mockGenerateContentFn).not.toHaveBeenCalledWith({
|
||||
expect(mockContentGenerator.generateContent).not.toHaveBeenCalledWith({
|
||||
model: initialModel,
|
||||
config: expect.any(Object),
|
||||
contents,
|
||||
});
|
||||
expect(mockGenerateContentFn).toHaveBeenCalledWith(
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
{
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
config: expect.any(Object),
|
||||
@@ -2427,73 +2239,4 @@ ${JSON.stringify(
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('setHistory', () => {
|
||||
it('should strip thought signatures when stripThoughts is true', () => {
|
||||
const mockChat = {
|
||||
setHistory: vi.fn(),
|
||||
};
|
||||
client['chat'] = mockChat as unknown as GeminiChat;
|
||||
|
||||
const historyWithThoughts: Content[] = [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: 'hello' }],
|
||||
},
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{ text: 'thinking...', thoughtSignature: 'thought-123' },
|
||||
{
|
||||
functionCall: { name: 'test', args: {} },
|
||||
thoughtSignature: 'thought-456',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
client.setHistory(historyWithThoughts, { stripThoughts: true });
|
||||
|
||||
const expectedHistory: Content[] = [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: 'hello' }],
|
||||
},
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{ text: 'thinking...' },
|
||||
{ functionCall: { name: 'test', args: {} } },
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
expect(mockChat.setHistory).toHaveBeenCalledWith(expectedHistory);
|
||||
});
|
||||
|
||||
it('should not strip thought signatures when stripThoughts is false', () => {
|
||||
const mockChat = {
|
||||
setHistory: vi.fn(),
|
||||
};
|
||||
client['chat'] = mockChat as unknown as GeminiChat;
|
||||
|
||||
const historyWithThoughts: Content[] = [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: 'hello' }],
|
||||
},
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{ text: 'thinking...', thoughtSignature: 'thought-123' },
|
||||
{ text: 'ok', thoughtSignature: 'thought-456' },
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
client.setHistory(historyWithThoughts, { stripThoughts: false });
|
||||
|
||||
expect(mockChat.setHistory).toHaveBeenCalledWith(historyWithThoughts);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -20,7 +20,6 @@ import type { ServerGeminiStreamEvent, ChatCompressionInfo } from './turn.js';
|
||||
import { CompressionStatus } from './turn.js';
|
||||
import { Turn, GeminiEventType } from './turn.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { UserTierId } from '../code_assist/types.js';
|
||||
import { getCoreSystemPrompt, getCompressionPrompt } from './prompts.js';
|
||||
import { getResponseText } from '../utils/partUtils.js';
|
||||
import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
|
||||
@@ -31,12 +30,8 @@ import { getErrorMessage } from '../utils/errors.js';
|
||||
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { tokenLimit } from './tokenLimits.js';
|
||||
import type { ChatRecordingService } from '../services/chatRecordingService.js';
|
||||
import type {
|
||||
ContentGenerator,
|
||||
ContentGeneratorConfig,
|
||||
} from './contentGenerator.js';
|
||||
import { AuthType, createContentGenerator } from './contentGenerator.js';
|
||||
import { ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import { AuthType } from './contentGenerator.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_THINKING_MODE,
|
||||
@@ -115,8 +110,6 @@ const COMPRESSION_PRESERVE_THRESHOLD = 0.3;
|
||||
|
||||
export class GeminiClient {
|
||||
private chat?: GeminiChat;
|
||||
private contentGenerator?: ContentGenerator;
|
||||
private readonly embeddingModel: string;
|
||||
private readonly generateContentConfig: GenerateContentConfig = {
|
||||
temperature: 0,
|
||||
topP: 1,
|
||||
@@ -135,33 +128,19 @@ export class GeminiClient {
|
||||
private hasFailedCompressionAttempt = false;
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
if (config.getProxy()) {
|
||||
setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
|
||||
}
|
||||
|
||||
this.embeddingModel = config.getEmbeddingModel();
|
||||
this.loopDetector = new LoopDetectionService(config);
|
||||
this.lastPromptId = this.config.getSessionId();
|
||||
}
|
||||
|
||||
async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
|
||||
this.contentGenerator = await createContentGenerator(
|
||||
contentGeneratorConfig,
|
||||
this.config,
|
||||
this.config.getSessionId(),
|
||||
);
|
||||
async initialize() {
|
||||
this.chat = await this.startChat();
|
||||
}
|
||||
|
||||
getContentGenerator(): ContentGenerator {
|
||||
if (!this.contentGenerator) {
|
||||
private getContentGeneratorOrFail(): ContentGenerator {
|
||||
if (!this.config.getContentGenerator()) {
|
||||
throw new Error('Content generator not initialized');
|
||||
}
|
||||
return this.contentGenerator;
|
||||
}
|
||||
|
||||
getUserTier(): UserTierId | undefined {
|
||||
return this.contentGenerator?.userTier;
|
||||
return this.config.getContentGenerator();
|
||||
}
|
||||
|
||||
async addHistory(content: Content) {
|
||||
@@ -176,39 +155,19 @@ export class GeminiClient {
|
||||
}
|
||||
|
||||
isInitialized(): boolean {
|
||||
return this.chat !== undefined && this.contentGenerator !== undefined;
|
||||
return this.chat !== undefined;
|
||||
}
|
||||
|
||||
getHistory(): Content[] {
|
||||
return this.getChat().getHistory();
|
||||
}
|
||||
|
||||
setHistory(
|
||||
history: Content[],
|
||||
{ stripThoughts = false }: { stripThoughts?: boolean } = {},
|
||||
) {
|
||||
const historyToSet = stripThoughts
|
||||
? history.map((content) => {
|
||||
const newContent = { ...content };
|
||||
if (newContent.parts) {
|
||||
newContent.parts = newContent.parts.map((part) => {
|
||||
if (
|
||||
part &&
|
||||
typeof part === 'object' &&
|
||||
'thoughtSignature' in part
|
||||
) {
|
||||
const newPart = { ...part };
|
||||
delete (newPart as { thoughtSignature?: string })
|
||||
.thoughtSignature;
|
||||
return newPart;
|
||||
}
|
||||
return part;
|
||||
});
|
||||
}
|
||||
return newContent;
|
||||
})
|
||||
: history;
|
||||
this.getChat().setHistory(historyToSet);
|
||||
stripThoughtsFromHistory() {
|
||||
this.getChat().stripThoughtsFromHistory();
|
||||
}
|
||||
|
||||
setHistory(history: Content[]) {
|
||||
this.getChat().setHistory(history);
|
||||
this.forceFullIdeContext = true;
|
||||
}
|
||||
|
||||
@@ -242,9 +201,11 @@ export class GeminiClient {
|
||||
this.forceFullIdeContext = true;
|
||||
this.hasFailedCompressionAttempt = false;
|
||||
const envParts = await getEnvironmentContext(this.config);
|
||||
|
||||
const toolRegistry = this.config.getToolRegistry();
|
||||
const toolDeclarations = toolRegistry.getFunctionDeclarations();
|
||||
const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
|
||||
|
||||
const history: Content[] = [
|
||||
{
|
||||
role: 'user',
|
||||
@@ -274,7 +235,6 @@ export class GeminiClient {
|
||||
: this.generateContentConfig;
|
||||
return new GeminiChat(
|
||||
this.config,
|
||||
this.getContentGenerator(),
|
||||
{
|
||||
systemInstruction,
|
||||
...generateContentConfigWithThinking,
|
||||
@@ -600,7 +560,7 @@ export class GeminiClient {
|
||||
};
|
||||
|
||||
const apiCall = () =>
|
||||
this.getContentGenerator().generateContent(
|
||||
this.getContentGeneratorOrFail().generateContent(
|
||||
{
|
||||
model,
|
||||
config: {
|
||||
@@ -711,7 +671,7 @@ export class GeminiClient {
|
||||
};
|
||||
|
||||
const apiCall = () =>
|
||||
this.getContentGenerator().generateContent(
|
||||
this.getContentGeneratorOrFail().generateContent(
|
||||
{
|
||||
model,
|
||||
config: requestConfig,
|
||||
@@ -751,12 +711,12 @@ export class GeminiClient {
|
||||
return [];
|
||||
}
|
||||
const embedModelParams: EmbedContentParameters = {
|
||||
model: this.embeddingModel,
|
||||
model: this.config.getEmbeddingModel(),
|
||||
contents: texts,
|
||||
};
|
||||
|
||||
const embedContentResponse =
|
||||
await this.getContentGenerator().embedContent(embedModelParams);
|
||||
await this.getContentGeneratorOrFail().embedContent(embedModelParams);
|
||||
if (
|
||||
!embedContentResponse.embeddings ||
|
||||
embedContentResponse.embeddings.length === 0
|
||||
@@ -802,7 +762,7 @@ export class GeminiClient {
|
||||
const model = this.config.getModel();
|
||||
|
||||
const { totalTokens: originalTokenCount } =
|
||||
await this.getContentGenerator().countTokens({
|
||||
await this.getContentGeneratorOrFail().countTokens({
|
||||
model,
|
||||
contents: curatedHistory,
|
||||
});
|
||||
@@ -877,7 +837,7 @@ export class GeminiClient {
|
||||
this.forceFullIdeContext = true;
|
||||
|
||||
const { totalTokens: newTokenCount } =
|
||||
await this.getContentGenerator().countTokens({
|
||||
await this.getContentGeneratorOrFail().countTokens({
|
||||
// model might change after calling `sendMessage`, so we get the newest value from config
|
||||
model: this.config.getModel(),
|
||||
contents: chat.getHistory(),
|
||||
|
||||
@@ -7,11 +7,11 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import type {
|
||||
Content,
|
||||
Models,
|
||||
GenerateContentConfig,
|
||||
Part,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import type { ContentGenerator } from '../core/contentGenerator.js';
|
||||
import {
|
||||
GeminiChat,
|
||||
EmptyStreamError,
|
||||
@@ -47,15 +47,6 @@ vi.mock('node:fs', () => {
|
||||
};
|
||||
});
|
||||
|
||||
// Mocks
|
||||
const mockModelsModule = {
|
||||
generateContent: vi.fn(),
|
||||
generateContentStream: vi.fn(),
|
||||
countTokens: vi.fn(),
|
||||
embedContent: vi.fn(),
|
||||
batchEmbedContents: vi.fn(),
|
||||
} as unknown as Models;
|
||||
|
||||
const { mockLogInvalidChunk, mockLogContentRetry, mockLogContentRetryFailure } =
|
||||
vi.hoisted(() => ({
|
||||
mockLogInvalidChunk: vi.fn(),
|
||||
@@ -70,12 +61,21 @@ vi.mock('../telemetry/loggers.js', () => ({
|
||||
}));
|
||||
|
||||
describe('GeminiChat', () => {
|
||||
let mockContentGenerator: ContentGenerator;
|
||||
let chat: GeminiChat;
|
||||
let mockConfig: Config;
|
||||
const config: GenerateContentConfig = {};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockContentGenerator = {
|
||||
generateContent: vi.fn(),
|
||||
generateContentStream: vi.fn(),
|
||||
countTokens: vi.fn(),
|
||||
embedContent: vi.fn(),
|
||||
batchEmbedContents: vi.fn(),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
mockConfig = {
|
||||
getSessionId: () => 'test-session-id',
|
||||
getTelemetryLogPromptsEnabled: () => true,
|
||||
@@ -97,12 +97,13 @@ describe('GeminiChat', () => {
|
||||
getToolRegistry: vi.fn().mockReturnValue({
|
||||
getTool: vi.fn(),
|
||||
}),
|
||||
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
|
||||
} as unknown as Config;
|
||||
|
||||
// Disable 429 simulation for tests
|
||||
setSimulate429(false);
|
||||
// Reset history for each test by creating a new instance
|
||||
chat = new GeminiChat(mockConfig, mockModelsModule, config, []);
|
||||
chat = new GeminiChat(mockConfig, config, []);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -160,7 +161,7 @@ describe('GeminiChat', () => {
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
|
||||
vi.mocked(mockModelsModule.generateContent).mockResolvedValue(
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
mockAfcResponse,
|
||||
);
|
||||
|
||||
@@ -218,7 +219,7 @@ describe('GeminiChat', () => {
|
||||
{ content: { role: 'model', parts: [{ text: 'some response' }] } },
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
vi.mocked(mockModelsModule.generateContent).mockResolvedValue(
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
mockResponse,
|
||||
);
|
||||
|
||||
@@ -247,7 +248,7 @@ describe('GeminiChat', () => {
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
|
||||
vi.mocked(mockModelsModule.generateContent).mockResolvedValue(
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
mixedContentResponse,
|
||||
);
|
||||
|
||||
@@ -302,7 +303,7 @@ describe('GeminiChat', () => {
|
||||
{ content: { role: 'model', parts: [{ thought: true }] } },
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
vi.mocked(mockModelsModule.generateContent).mockResolvedValue(
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
emptyModelResponse,
|
||||
);
|
||||
|
||||
@@ -348,11 +349,13 @@ describe('GeminiChat', () => {
|
||||
],
|
||||
text: () => 'response',
|
||||
} as unknown as GenerateContentResponse;
|
||||
vi.mocked(mockModelsModule.generateContent).mockResolvedValue(response);
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
response,
|
||||
);
|
||||
|
||||
await chat.sendMessage({ message: 'hello' }, 'prompt-id-1');
|
||||
|
||||
expect(mockModelsModule.generateContent).toHaveBeenCalledWith(
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
{
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
||||
@@ -390,7 +393,7 @@ describe('GeminiChat', () => {
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
|
||||
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
|
||||
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||
streamWithToolCall,
|
||||
);
|
||||
|
||||
@@ -442,7 +445,7 @@ describe('GeminiChat', () => {
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
|
||||
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
|
||||
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||
streamWithNoFinish,
|
||||
);
|
||||
|
||||
@@ -487,7 +490,7 @@ describe('GeminiChat', () => {
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
|
||||
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
|
||||
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||
streamWithInvalidEnd,
|
||||
);
|
||||
|
||||
@@ -543,7 +546,7 @@ describe('GeminiChat', () => {
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
|
||||
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
|
||||
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||
multiChunkStream,
|
||||
);
|
||||
|
||||
@@ -593,7 +596,7 @@ describe('GeminiChat', () => {
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
|
||||
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
|
||||
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||
multiChunkStream,
|
||||
);
|
||||
|
||||
@@ -650,7 +653,7 @@ describe('GeminiChat', () => {
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
|
||||
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
|
||||
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||
multiChunkStream,
|
||||
);
|
||||
|
||||
@@ -697,7 +700,7 @@ describe('GeminiChat', () => {
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
|
||||
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
|
||||
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||
mixedContentStream,
|
||||
);
|
||||
|
||||
@@ -759,7 +762,7 @@ describe('GeminiChat', () => {
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
|
||||
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||
emptyStreamResponse,
|
||||
);
|
||||
|
||||
@@ -811,7 +814,7 @@ describe('GeminiChat', () => {
|
||||
text: () => 'response',
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
|
||||
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||
response,
|
||||
);
|
||||
|
||||
@@ -823,7 +826,7 @@ describe('GeminiChat', () => {
|
||||
// consume stream to trigger internal logic
|
||||
}
|
||||
|
||||
expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith(
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledWith(
|
||||
{
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
||||
@@ -1012,7 +1015,7 @@ describe('GeminiChat', () => {
|
||||
describe('sendMessageStream with retries', () => {
|
||||
it('should yield a RETRY event when an invalid stream is encountered', async () => {
|
||||
// ARRANGE: Mock the stream to fail once, then succeed.
|
||||
vi.mocked(mockModelsModule.generateContentStream)
|
||||
vi.mocked(mockContentGenerator.generateContentStream)
|
||||
.mockImplementationOnce(async () =>
|
||||
// First attempt: An invalid stream with an empty text part.
|
||||
(async function* () {
|
||||
@@ -1053,7 +1056,7 @@ describe('GeminiChat', () => {
|
||||
});
|
||||
it('should retry on invalid content, succeed, and report metrics', async () => {
|
||||
// Use mockImplementationOnce to provide a fresh, promise-wrapped generator for each attempt.
|
||||
vi.mocked(mockModelsModule.generateContentStream)
|
||||
vi.mocked(mockContentGenerator.generateContentStream)
|
||||
.mockImplementationOnce(async () =>
|
||||
// First call returns an invalid stream
|
||||
(async function* () {
|
||||
@@ -1089,7 +1092,9 @@ describe('GeminiChat', () => {
|
||||
expect(mockLogInvalidChunk).toHaveBeenCalledTimes(1);
|
||||
expect(mockLogContentRetry).toHaveBeenCalledTimes(1);
|
||||
expect(mockLogContentRetryFailure).not.toHaveBeenCalled();
|
||||
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(
|
||||
2,
|
||||
);
|
||||
|
||||
// Check for a retry event
|
||||
expect(chunks.some((c) => c.type === StreamEventType.RETRY)).toBe(true);
|
||||
@@ -1118,7 +1123,7 @@ describe('GeminiChat', () => {
|
||||
});
|
||||
|
||||
it('should fail after all retries on persistent invalid content and report metrics', async () => {
|
||||
vi.mocked(mockModelsModule.generateContentStream).mockImplementation(
|
||||
vi.mocked(mockContentGenerator.generateContentStream).mockImplementation(
|
||||
async () =>
|
||||
(async function* () {
|
||||
yield {
|
||||
@@ -1150,7 +1155,9 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
// Should be called 3 times (initial + 2 retries)
|
||||
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(3);
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(
|
||||
3,
|
||||
);
|
||||
expect(mockLogInvalidChunk).toHaveBeenCalledTimes(3);
|
||||
expect(mockLogContentRetry).toHaveBeenCalledTimes(2);
|
||||
expect(mockLogContentRetryFailure).toHaveBeenCalledTimes(1);
|
||||
@@ -1169,7 +1176,7 @@ describe('GeminiChat', () => {
|
||||
chat.setHistory(initialHistory);
|
||||
|
||||
// 2. Mock the API to fail once with an empty stream, then succeed.
|
||||
vi.mocked(mockModelsModule.generateContentStream)
|
||||
vi.mocked(mockContentGenerator.generateContentStream)
|
||||
.mockImplementationOnce(async () =>
|
||||
(async function* () {
|
||||
yield {
|
||||
@@ -1264,7 +1271,7 @@ describe('GeminiChat', () => {
|
||||
} as unknown as GenerateContentResponse;
|
||||
|
||||
// 2. Mock the API to return our controllable promises in order
|
||||
vi.mocked(mockModelsModule.generateContent)
|
||||
vi.mocked(mockContentGenerator.generateContent)
|
||||
.mockReturnValueOnce(firstCallPromise)
|
||||
.mockReturnValueOnce(secondCallPromise);
|
||||
|
||||
@@ -1285,8 +1292,8 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 5. CRUCIAL CHECK: At this point, only the first API call should have been made.
|
||||
// The second call should be waiting on `sendPromise`.
|
||||
expect(mockModelsModule.generateContent).toHaveBeenCalledTimes(1);
|
||||
expect(mockModelsModule.generateContent).toHaveBeenCalledWith(
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(1);
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
contents: expect.arrayContaining([
|
||||
expect.objectContaining({ parts: [{ text: 'first' }] }),
|
||||
@@ -1303,8 +1310,8 @@ describe('GeminiChat', () => {
|
||||
await new Promise(process.nextTick);
|
||||
|
||||
// 7. CRUCIAL CHECK: Now, the second API call should have been made.
|
||||
expect(mockModelsModule.generateContent).toHaveBeenCalledTimes(2);
|
||||
expect(mockModelsModule.generateContent).toHaveBeenCalledWith(
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(2);
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
contents: expect.arrayContaining([
|
||||
expect.objectContaining({ parts: [{ text: 'second' }] }),
|
||||
@@ -1320,7 +1327,7 @@ describe('GeminiChat', () => {
|
||||
});
|
||||
it('should retry if the model returns a completely empty stream (no chunks)', async () => {
|
||||
// 1. Mock the API to return an empty stream first, then a valid one.
|
||||
vi.mocked(mockModelsModule.generateContentStream)
|
||||
vi.mocked(mockContentGenerator.generateContentStream)
|
||||
.mockImplementationOnce(
|
||||
// First call resolves to an async generator that yields nothing.
|
||||
async () => (async function* () {})(),
|
||||
@@ -1353,7 +1360,7 @@ describe('GeminiChat', () => {
|
||||
}
|
||||
|
||||
// 3. Assert the results.
|
||||
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2);
|
||||
expect(
|
||||
chunks.some(
|
||||
(c) =>
|
||||
@@ -1417,7 +1424,7 @@ describe('GeminiChat', () => {
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
|
||||
vi.mocked(mockModelsModule.generateContentStream)
|
||||
vi.mocked(mockContentGenerator.generateContentStream)
|
||||
.mockResolvedValueOnce(firstStreamGenerator)
|
||||
.mockResolvedValueOnce(secondStreamGenerator);
|
||||
|
||||
@@ -1436,7 +1443,7 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
// 5. Assert that only one API call has been made so far.
|
||||
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(1);
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(1);
|
||||
|
||||
// 6. Unblock and fully consume the first stream to completion.
|
||||
continueFirstStream!();
|
||||
@@ -1451,7 +1458,7 @@ describe('GeminiChat', () => {
|
||||
await secondStreamIterator.next();
|
||||
|
||||
// 9. The second API call should now have been made.
|
||||
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2);
|
||||
|
||||
// 10. FIX: Fully consume the second stream to ensure recordHistory is called.
|
||||
await secondStreamIterator.next(); // This finishes the iterator.
|
||||
@@ -1471,7 +1478,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
it('should discard valid partial content from a failed attempt upon retry', async () => {
|
||||
// ARRANGE: Mock the stream to fail on the first attempt after yielding some valid content.
|
||||
vi.mocked(mockModelsModule.generateContentStream)
|
||||
vi.mocked(mockContentGenerator.generateContentStream)
|
||||
.mockImplementationOnce(async () =>
|
||||
// First attempt: yields one valid chunk, then one invalid chunk
|
||||
(async function* () {
|
||||
@@ -1517,7 +1524,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
// ASSERT
|
||||
// Check that a retry happened
|
||||
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2);
|
||||
expect(events.some((e) => e.type === StreamEventType.RETRY)).toBe(true);
|
||||
|
||||
// Check the final recorded history
|
||||
@@ -1532,4 +1539,41 @@ describe('GeminiChat', () => {
|
||||
'This valid part should be discarded',
|
||||
);
|
||||
});
|
||||
|
||||
describe('stripThoughtsFromHistory', () => {
|
||||
it('should strip thought signatures', () => {
|
||||
chat.setHistory([
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: 'hello' }],
|
||||
},
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{ text: 'thinking...', thoughtSignature: 'thought-123' },
|
||||
{
|
||||
functionCall: { name: 'test', args: {} },
|
||||
thoughtSignature: 'thought-456',
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
|
||||
chat.stripThoughtsFromHistory();
|
||||
|
||||
expect(chat.getHistory()).toEqual([
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: 'hello' }],
|
||||
},
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{ text: 'thinking...' },
|
||||
{ functionCall: { name: 'test', args: {} } },
|
||||
],
|
||||
},
|
||||
]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -18,7 +18,6 @@ import type {
|
||||
import { toParts } from '../code_assist/converter.js';
|
||||
import { createUserContent } from '@google/genai';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import { AuthType } from './contentGenerator.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
@@ -172,7 +171,6 @@ export class GeminiChat {
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly contentGenerator: ContentGenerator,
|
||||
private readonly generationConfig: GenerateContentConfig = {},
|
||||
private history: Content[] = [],
|
||||
) {
|
||||
@@ -287,7 +285,7 @@ export class GeminiChat {
|
||||
);
|
||||
}
|
||||
|
||||
return this.contentGenerator.generateContent(
|
||||
return this.config.getContentGenerator().generateContent(
|
||||
{
|
||||
model: modelToUse,
|
||||
contents: requestContents,
|
||||
@@ -498,7 +496,7 @@ export class GeminiChat {
|
||||
);
|
||||
}
|
||||
|
||||
return this.contentGenerator.generateContentStream(
|
||||
return this.config.getContentGenerator().generateContentStream(
|
||||
{
|
||||
model: modelToUse,
|
||||
contents: requestContents,
|
||||
@@ -570,10 +568,28 @@ export class GeminiChat {
|
||||
addHistory(content: Content): void {
|
||||
this.history.push(content);
|
||||
}
|
||||
|
||||
setHistory(history: Content[]): void {
|
||||
this.history = history;
|
||||
}
|
||||
|
||||
stripThoughtsFromHistory(): void {
|
||||
this.history = this.history.map((content) => {
|
||||
const newContent = { ...content };
|
||||
if (newContent.parts) {
|
||||
newContent.parts = newContent.parts.map((part) => {
|
||||
if (part && typeof part === 'object' && 'thoughtSignature' in part) {
|
||||
const newPart = { ...part };
|
||||
delete (newPart as { thoughtSignature?: string }).thoughtSignature;
|
||||
return newPart;
|
||||
}
|
||||
return part;
|
||||
});
|
||||
}
|
||||
return newContent;
|
||||
});
|
||||
}
|
||||
|
||||
setTools(tools: Tool[]): void {
|
||||
this.generationConfig.tools = tools;
|
||||
}
|
||||
|
||||
@@ -55,8 +55,6 @@ async function createMockConfig(
|
||||
};
|
||||
const config = new Config(configParams);
|
||||
await config.initialize();
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
await config.refreshAuth('test-auth' as any);
|
||||
|
||||
// Mock ToolRegistry
|
||||
const mockToolRegistry = {
|
||||
@@ -164,15 +162,13 @@ describe('subagent.ts', () => {
|
||||
// Helper to safely access generationConfig from mock calls
|
||||
const getGenerationConfigFromMock = (
|
||||
callIndex = 0,
|
||||
): GenerateContentConfig & { systemInstruction?: string | Content } => {
|
||||
): GenerateContentConfig => {
|
||||
const callArgs = vi.mocked(GeminiChat).mock.calls[callIndex];
|
||||
const generationConfig = callArgs?.[2];
|
||||
const generationConfig = callArgs?.[1];
|
||||
// Ensure it's defined before proceeding
|
||||
expect(generationConfig).toBeDefined();
|
||||
if (!generationConfig) throw new Error('generationConfig is undefined');
|
||||
return generationConfig as GenerateContentConfig & {
|
||||
systemInstruction?: string | Content;
|
||||
};
|
||||
return generationConfig as GenerateContentConfig;
|
||||
};
|
||||
|
||||
describe('create (Tool Validation)', () => {
|
||||
@@ -347,7 +343,7 @@ describe('subagent.ts', () => {
|
||||
);
|
||||
|
||||
// Check History (should include environment context)
|
||||
const history = callArgs[3];
|
||||
const history = callArgs[2];
|
||||
expect(history).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'Env Context' }] },
|
||||
{
|
||||
@@ -420,7 +416,7 @@ describe('subagent.ts', () => {
|
||||
|
||||
const callArgs = vi.mocked(GeminiChat).mock.calls[0];
|
||||
const generationConfig = getGenerationConfigFromMock();
|
||||
const history = callArgs[3];
|
||||
const history = callArgs[2];
|
||||
|
||||
expect(generationConfig.systemInstruction).toBeUndefined();
|
||||
expect(history).toEqual([
|
||||
|
||||
@@ -10,7 +10,6 @@ import type { AnyDeclarativeTool } from '../tools/tools.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { ToolCallRequestInfo } from './turn.js';
|
||||
import { executeToolCall } from './nonInteractiveToolExecutor.js';
|
||||
import { createContentGenerator } from './contentGenerator.js';
|
||||
import { getEnvironmentContext } from '../utils/environmentContext.js';
|
||||
import type {
|
||||
Content,
|
||||
@@ -635,9 +634,7 @@ export class SubAgentScope {
|
||||
: undefined;
|
||||
|
||||
try {
|
||||
const generationConfig: GenerateContentConfig & {
|
||||
systemInstruction?: string | Content;
|
||||
} = {
|
||||
const generationConfig: GenerateContentConfig = {
|
||||
temperature: this.modelConfig.temp,
|
||||
topP: this.modelConfig.top_p,
|
||||
};
|
||||
@@ -646,17 +643,10 @@ export class SubAgentScope {
|
||||
generationConfig.systemInstruction = systemInstruction;
|
||||
}
|
||||
|
||||
const contentGenerator = await createContentGenerator(
|
||||
this.runtimeContext.getContentGeneratorConfig(),
|
||||
this.runtimeContext,
|
||||
this.runtimeContext.getSessionId(),
|
||||
);
|
||||
|
||||
this.runtimeContext.setModel(this.modelConfig.model);
|
||||
|
||||
return new GeminiChat(
|
||||
this.runtimeContext,
|
||||
contentGenerator,
|
||||
generationConfig,
|
||||
start_history,
|
||||
);
|
||||
|
||||
@@ -6,10 +6,10 @@
|
||||
|
||||
import type { Mock } from 'vitest';
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import type { Content, GoogleGenAI, Models } from '@google/genai';
|
||||
import type { Content } from '@google/genai';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { NextSpeakerResponse } from './nextSpeakerChecker.js';
|
||||
import { checkNextSpeaker } from './nextSpeakerChecker.js';
|
||||
import { GeminiChat } from '../core/geminiChat.js';
|
||||
@@ -44,73 +44,28 @@ vi.mock('node:fs', () => {
|
||||
vi.mock('../core/client.js');
|
||||
vi.mock('../config/config.js');
|
||||
|
||||
// Define mocks for GoogleGenAI and Models instances that will be used across tests
|
||||
const mockModelsInstance = {
|
||||
generateContent: vi.fn(),
|
||||
generateContentStream: vi.fn(),
|
||||
countTokens: vi.fn(),
|
||||
embedContent: vi.fn(),
|
||||
batchEmbedContents: vi.fn(),
|
||||
} as unknown as Models;
|
||||
|
||||
const mockGoogleGenAIInstance = {
|
||||
getGenerativeModel: vi.fn().mockReturnValue(mockModelsInstance),
|
||||
// Add other methods of GoogleGenAI if they are directly used by GeminiChat constructor or its methods
|
||||
} as unknown as GoogleGenAI;
|
||||
|
||||
vi.mock('@google/genai', async () => {
|
||||
const actualGenAI =
|
||||
await vi.importActual<typeof import('@google/genai')>('@google/genai');
|
||||
return {
|
||||
...actualGenAI,
|
||||
GoogleGenAI: vi.fn(() => mockGoogleGenAIInstance), // Mock constructor to return the predefined instance
|
||||
// If Models is instantiated directly in GeminiChat, mock its constructor too
|
||||
// For now, assuming Models instance is obtained via getGenerativeModel
|
||||
};
|
||||
});
|
||||
|
||||
describe('checkNextSpeaker', () => {
|
||||
let chatInstance: GeminiChat;
|
||||
let mockConfig: Config;
|
||||
let mockGeminiClient: GeminiClient;
|
||||
let MockConfig: Mock;
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
beforeEach(() => {
|
||||
MockConfig = vi.mocked(Config);
|
||||
const mockConfigInstance = new MockConfig(
|
||||
'test-api-key',
|
||||
'gemini-pro',
|
||||
false,
|
||||
'.',
|
||||
false,
|
||||
undefined,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
vi.resetAllMocks();
|
||||
mockConfig = {
|
||||
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
|
||||
getSessionId: vi.fn().mockReturnValue('test-session-id'),
|
||||
getModel: () => 'test-model',
|
||||
storage: {
|
||||
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
|
||||
},
|
||||
} as unknown as Config;
|
||||
|
||||
// Mock the methods that ChatRecordingService needs
|
||||
mockConfigInstance.getSessionId = vi
|
||||
.fn()
|
||||
.mockReturnValue('test-session-id');
|
||||
mockConfigInstance.getProjectRoot = vi
|
||||
.fn()
|
||||
.mockReturnValue('/test/project/root');
|
||||
mockConfigInstance.storage = {
|
||||
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
|
||||
};
|
||||
|
||||
mockGeminiClient = new GeminiClient(mockConfigInstance);
|
||||
|
||||
// Reset mocks before each test to ensure test isolation
|
||||
vi.mocked(mockModelsInstance.generateContent).mockReset();
|
||||
vi.mocked(mockModelsInstance.generateContentStream).mockReset();
|
||||
mockGeminiClient = new GeminiClient(mockConfig);
|
||||
|
||||
// GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor
|
||||
chatInstance = new GeminiChat(
|
||||
mockConfigInstance,
|
||||
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
|
||||
mockConfig,
|
||||
{},
|
||||
[], // initial history
|
||||
);
|
||||
@@ -120,7 +75,7 @@ describe('checkNextSpeaker', () => {
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should return null if history is empty', async () => {
|
||||
@@ -135,9 +90,9 @@ describe('checkNextSpeaker', () => {
|
||||
});
|
||||
|
||||
it('should return null if the last speaker was the user', async () => {
|
||||
(chatInstance.getHistory as Mock).mockReturnValue([
|
||||
vi.mocked(chatInstance.getHistory).mockReturnValue([
|
||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||
] as Content[]);
|
||||
]);
|
||||
const result = await checkNextSpeaker(
|
||||
chatInstance,
|
||||
mockGeminiClient,
|
||||
|
||||
Reference in New Issue
Block a user