mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-16 00:51:25 -07:00
refactor(core): Move generateEmbedding to BaseLlmClient (#8442)
This commit is contained in:
@@ -40,9 +40,11 @@ vi.mock('../utils/retry.js', () => ({
|
||||
}));
|
||||
|
||||
// Stub installed as the `generateContent` member of the mock generator below.
const mockGenerateContent = vi.fn();
|
||||
// Stub installed as the `embedContent` member of the mock generator below;
// the generateEmbedding tests program its resolved / rejected value.
const mockEmbedContent = vi.fn();
|
||||
|
||||
// Mock ContentGenerator built from just the two stubs above. The literal is
// cast through `unknown` to Mocked<ContentGenerator> rather than being declared
// with that type directly.
const mockContentGenerator = {
|
||||
generateContent: mockGenerateContent,
|
||||
embedContent: mockEmbedContent,
|
||||
} as unknown as Mocked<ContentGenerator>;
|
||||
|
||||
const mockConfig = {
|
||||
@@ -50,6 +52,7 @@ const mockConfig = {
|
||||
getContentGeneratorConfig: vi
|
||||
.fn()
|
||||
.mockReturnValue({ authType: AuthType.USE_GEMINI }),
|
||||
getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
// Helper to create a mock GenerateContentResponse
|
||||
@@ -288,4 +291,93 @@ describe('BaseLlmClient', () => {
|
||||
expect(reportError).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
// Exercises client.generateEmbedding against the mocked embedContent stub:
// the happy path, the empty-input short circuit, every response-validation
// failure, and propagation of a rejected API call.
describe('generateEmbedding', () => {
  const inputTexts = ['hello world', 'goodbye world'];
  const embeddingModel = 'test-embedding-model';

  // Embeds the shared inputs and asserts the call rejects with `message`.
  const expectRejection = (message: string) =>
    expect(client.generateEmbedding(inputTexts)).rejects.toThrow(message);

  it('should call embedContent with correct parameters and return embeddings', async () => {
    const expectedVectors = [
      [0.1, 0.2, 0.3],
      [0.4, 0.5, 0.6],
    ];
    // One `{ values }` entry per expected vector, in the same order.
    mockEmbedContent.mockResolvedValue({
      embeddings: expectedVectors.map((values) => ({ values })),
    });

    const actual = await client.generateEmbedding(inputTexts);

    expect(mockEmbedContent).toHaveBeenCalledTimes(1);
    expect(mockEmbedContent).toHaveBeenCalledWith({
      model: embeddingModel,
      contents: inputTexts,
    });
    expect(actual).toEqual(expectedVectors);
  });

  it('should return an empty array if an empty array is passed', async () => {
    const actual = await client.generateEmbedding([]);

    expect(actual).toEqual([]);
    // The short circuit must happen before any API call.
    expect(mockEmbedContent).not.toHaveBeenCalled();
  });

  it('should throw an error if API response has no embeddings array', async () => {
    mockEmbedContent.mockResolvedValue({});

    await expectRejection('No embeddings found in API response.');
  });

  it('should throw an error if API response has an empty embeddings array', async () => {
    mockEmbedContent.mockResolvedValue({ embeddings: [] });

    await expectRejection('No embeddings found in API response.');
  });

  it('should throw an error if API returns a mismatched number of embeddings', async () => {
    // A single embedding for two input texts.
    mockEmbedContent.mockResolvedValue({
      embeddings: [{ values: [1, 2, 3] }],
    });

    await expectRejection(
      'API returned a mismatched number of embeddings. Expected 2, got 1.',
    );
  });

  it('should throw an error if any embedding has nullish values', async () => {
    // The second entry is the invalid one.
    mockEmbedContent.mockResolvedValue({
      embeddings: [{ values: [1, 2, 3] }, { values: undefined }],
    });

    await expectRejection(
      'API returned an empty embedding for input text at index 1: "goodbye world"',
    );
  });

  it('should throw an error if any embedding has an empty values array', async () => {
    // The first entry is the invalid one.
    mockEmbedContent.mockResolvedValue({
      embeddings: [{ values: [] }, { values: [1, 2, 3] }],
    });

    await expectRejection(
      'API returned an empty embedding for input text at index 0: "hello world"',
    );
  });

  it('should propagate errors from the API call', async () => {
    mockEmbedContent.mockRejectedValue(new Error('API Failure'));

    await expectRejection('API Failure');
  });
});
|
||||
});
|
||||
|
||||
@@ -4,7 +4,12 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Content, GenerateContentConfig, Part } from '@google/genai';
|
||||
import type {
|
||||
Content,
|
||||
GenerateContentConfig,
|
||||
Part,
|
||||
EmbedContentParameters,
|
||||
} from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import { getResponseText } from '../utils/partUtils.js';
|
||||
@@ -156,6 +161,41 @@ export class BaseLlmClient {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Requests one embedding vector per input text from the content generator.
 *
 * The model name is read from `this.config.getEmbeddingModel()` and all texts
 * are sent in a single `embedContent` call.
 *
 * @param texts - Texts to embed. A falsy or empty array returns `[]` without
 *   calling the content generator.
 * @returns The `values` array of each returned embedding, in response order.
 * @throws Error if the response has no embeddings, if the number of embeddings
 *   differs from `texts.length`, or if any embedding's `values` is missing or
 *   empty. The last message quotes the offending input text verbatim.
 */
async generateEmbedding(texts: string[]): Promise<number[][]> {
  const nothingToEmbed = !texts || texts.length === 0;
  if (nothingToEmbed) {
    return [];
  }

  const request: EmbedContentParameters = {
    model: this.config.getEmbeddingModel(),
    contents: texts,
  };
  const response = await this.contentGenerator.embedContent(request);
  const embeddings = response.embeddings;

  if (!embeddings || embeddings.length === 0) {
    throw new Error('No embeddings found in API response.');
  }

  const received = embeddings.length;
  if (received !== texts.length) {
    throw new Error(
      `API returned a mismatched number of embeddings. Expected ${texts.length}, got ${received}.`,
    );
  }

  return embeddings.map((item, position) => {
    const vector = item.values;
    if (vector && vector.length !== 0) {
      return vector;
    }
    // Embeddings are matched to inputs by position, so report the text at the same index.
    throw new Error(
      `API returned an empty embedding for input text at index ${position}: "${texts[position]}"`,
    );
  });
}
|
||||
|
||||
private cleanJsonResponse(text: string, model: string): string {
|
||||
const prefix = '```json';
|
||||
const suffix = '```';
|
||||
|
||||
@@ -237,7 +237,6 @@ describe('Gemini Client (client.ts)', () => {
|
||||
generateContent: mockGenerateContentFn,
|
||||
generateContentStream: vi.fn(),
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }),
|
||||
embedContent: vi.fn(),
|
||||
batchEmbedContents: vi.fn(),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
@@ -312,97 +311,6 @@ describe('Gemini Client (client.ts)', () => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
// Exercises client.generateEmbedding against mockContentGenerator.embedContent:
// the happy path, the empty-input short circuit, every response-validation
// failure, and propagation of a rejected API call.
describe('generateEmbedding', () => {
  const inputTexts = ['hello world', 'goodbye world'];
  const embeddingModel = 'test-embedding-model';

  // Looked up on each call so the mock current at test time is always used.
  const embedContentMock = () => vi.mocked(mockContentGenerator.embedContent);

  // Embeds the shared inputs and asserts the call rejects with `message`.
  const expectRejection = (message: string) =>
    expect(client.generateEmbedding(inputTexts)).rejects.toThrow(message);

  it('should call embedContent with correct parameters and return embeddings', async () => {
    const expectedVectors = [
      [0.1, 0.2, 0.3],
      [0.4, 0.5, 0.6],
    ];
    // One `{ values }` entry per expected vector, in the same order.
    embedContentMock().mockResolvedValue({
      embeddings: expectedVectors.map((values) => ({ values })),
    });

    const actual = await client.generateEmbedding(inputTexts);

    expect(mockContentGenerator.embedContent).toHaveBeenCalledTimes(1);
    expect(mockContentGenerator.embedContent).toHaveBeenCalledWith({
      model: embeddingModel,
      contents: inputTexts,
    });
    expect(actual).toEqual(expectedVectors);
  });

  it('should return an empty array if an empty array is passed', async () => {
    const actual = await client.generateEmbedding([]);

    expect(actual).toEqual([]);
    // The short circuit must happen before any API call.
    expect(mockContentGenerator.embedContent).not.toHaveBeenCalled();
  });

  it('should throw an error if API response has no embeddings array', async () => {
    embedContentMock().mockResolvedValue({});

    await expectRejection('No embeddings found in API response.');
  });

  it('should throw an error if API response has an empty embeddings array', async () => {
    embedContentMock().mockResolvedValue({ embeddings: [] });

    await expectRejection('No embeddings found in API response.');
  });

  it('should throw an error if API returns a mismatched number of embeddings', async () => {
    // A single embedding for two input texts.
    embedContentMock().mockResolvedValue({
      embeddings: [{ values: [1, 2, 3] }],
    });

    await expectRejection(
      'API returned a mismatched number of embeddings. Expected 2, got 1.',
    );
  });

  it('should throw an error if any embedding has nullish values', async () => {
    // The second entry is the invalid one.
    embedContentMock().mockResolvedValue({
      embeddings: [{ values: [1, 2, 3] }, { values: undefined }],
    });

    await expectRejection(
      'API returned an empty embedding for input text at index 1: "goodbye world"',
    );
  });

  it('should throw an error if any embedding has an empty values array', async () => {
    // The first entry is the invalid one.
    embedContentMock().mockResolvedValue({
      embeddings: [{ values: [] }, { values: [1, 2, 3] }],
    });

    await expectRejection(
      'API returned an empty embedding for input text at index 0: "hello world"',
    );
  });

  it('should propagate errors from the API call', async () => {
    embedContentMock().mockRejectedValue(new Error('API Failure'));

    await expectRejection('API Failure');
  });
});
|
||||
|
||||
describe('generateJson', () => {
|
||||
it('should call generateContent with the correct parameters', async () => {
|
||||
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
*/
|
||||
|
||||
import type {
|
||||
EmbedContentParameters,
|
||||
GenerateContentConfig,
|
||||
PartListUnion,
|
||||
Content,
|
||||
@@ -758,41 +757,6 @@ export class GeminiClient {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Requests one embedding vector per input text from the content generator
 * returned by `getContentGeneratorOrFail()`.
 *
 * The model name is read from `this.config.getEmbeddingModel()` and all texts
 * are sent in a single `embedContent` call.
 *
 * @param texts - Texts to embed. A falsy or empty array returns `[]` before
 *   the content generator is looked up.
 * @returns The `values` array of each returned embedding, in response order.
 * @throws Error if the response has no embeddings, if the number of embeddings
 *   differs from `texts.length`, or if any embedding's `values` is missing or
 *   empty. The last message quotes the offending input text verbatim.
 */
async generateEmbedding(texts: string[]): Promise<number[][]> {
  const nothingToEmbed = !texts || texts.length === 0;
  if (nothingToEmbed) {
    return [];
  }

  const request: EmbedContentParameters = {
    model: this.config.getEmbeddingModel(),
    contents: texts,
  };
  const response = await this.getContentGeneratorOrFail().embedContent(request);
  const embeddings = response.embeddings;

  if (!embeddings || embeddings.length === 0) {
    throw new Error('No embeddings found in API response.');
  }

  const received = embeddings.length;
  if (received !== texts.length) {
    throw new Error(
      `API returned a mismatched number of embeddings. Expected ${texts.length}, got ${received}.`,
    );
  }

  // Returns the numeric vector of one embedding, or throws naming the input at the same position.
  const toVector = (item: (typeof embeddings)[number], position: number) => {
    const vector = item.values;
    if (vector && vector.length !== 0) {
      return vector;
    }
    throw new Error(
      `API returned an empty embedding for input text at index ${position}: "${texts[position]}"`,
    );
  };

  return embeddings.map((item, position) => toVector(item, position));
}
|
||||
|
||||
async tryCompressChat(
|
||||
prompt_id: string,
|
||||
force: boolean = false,
|
||||
|
||||
Reference in New Issue
Block a user