refactor(core): Move generateEmbedding to BaseLlmClient (#8442)

This commit is contained in:
Abhi
2025-09-15 11:33:30 -04:00
committed by GitHub
parent 5cbab75c7d
commit b5af662477
4 changed files with 133 additions and 129 deletions

View File

@@ -40,9 +40,11 @@ vi.mock('../utils/retry.js', () => ({
}));
const mockGenerateContent = vi.fn();
const mockEmbedContent = vi.fn();
const mockContentGenerator = {
generateContent: mockGenerateContent,
embedContent: mockEmbedContent,
} as unknown as Mocked<ContentGenerator>;
const mockConfig = {
@@ -50,6 +52,7 @@ const mockConfig = {
getContentGeneratorConfig: vi
.fn()
.mockReturnValue({ authType: AuthType.USE_GEMINI }),
getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'),
} as unknown as Mocked<Config>;
// Helper to create a mock GenerateContentResponse
@@ -288,4 +291,93 @@ describe('BaseLlmClient', () => {
expect(reportError).not.toHaveBeenCalled();
});
});
describe('generateEmbedding', () => {
const texts = ['hello world', 'goodbye world'];
const testEmbeddingModel = 'test-embedding-model';
it('should call embedContent with correct parameters and return embeddings', async () => {
const mockEmbeddings = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
];
mockEmbedContent.mockResolvedValue({
embeddings: [
{ values: mockEmbeddings[0] },
{ values: mockEmbeddings[1] },
],
});
const result = await client.generateEmbedding(texts);
expect(mockEmbedContent).toHaveBeenCalledTimes(1);
expect(mockEmbedContent).toHaveBeenCalledWith({
model: testEmbeddingModel,
contents: texts,
});
expect(result).toEqual(mockEmbeddings);
});
it('should return an empty array if an empty array is passed', async () => {
const result = await client.generateEmbedding([]);
expect(result).toEqual([]);
expect(mockEmbedContent).not.toHaveBeenCalled();
});
it('should throw an error if API response has no embeddings array', async () => {
mockEmbedContent.mockResolvedValue({});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'No embeddings found in API response.',
);
});
it('should throw an error if API response has an empty embeddings array', async () => {
mockEmbedContent.mockResolvedValue({
embeddings: [],
});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'No embeddings found in API response.',
);
});
it('should throw an error if API returns a mismatched number of embeddings', async () => {
mockEmbedContent.mockResolvedValue({
embeddings: [{ values: [1, 2, 3] }], // Only one for two texts
});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API returned a mismatched number of embeddings. Expected 2, got 1.',
);
});
it('should throw an error if any embedding has nullish values', async () => {
mockEmbedContent.mockResolvedValue({
embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad
});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API returned an empty embedding for input text at index 1: "goodbye world"',
);
});
it('should throw an error if any embedding has an empty values array', async () => {
mockEmbedContent.mockResolvedValue({
embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad
});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API returned an empty embedding for input text at index 0: "hello world"',
);
});
it('should propagate errors from the API call', async () => {
mockEmbedContent.mockRejectedValue(new Error('API Failure'));
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API Failure',
);
});
});
});

View File

@@ -4,7 +4,12 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { Content, GenerateContentConfig, Part } from '@google/genai';
import type {
Content,
GenerateContentConfig,
Part,
EmbedContentParameters,
} from '@google/genai';
import type { Config } from '../config/config.js';
import type { ContentGenerator } from './contentGenerator.js';
import { getResponseText } from '../utils/partUtils.js';
@@ -156,6 +161,41 @@ export class BaseLlmClient {
}
}
async generateEmbedding(texts: string[]): Promise<number[][]> {
if (!texts || texts.length === 0) {
return [];
}
const embedModelParams: EmbedContentParameters = {
model: this.config.getEmbeddingModel(),
contents: texts,
};
const embedContentResponse =
await this.contentGenerator.embedContent(embedModelParams);
if (
!embedContentResponse.embeddings ||
embedContentResponse.embeddings.length === 0
) {
throw new Error('No embeddings found in API response.');
}
if (embedContentResponse.embeddings.length !== texts.length) {
throw new Error(
`API returned a mismatched number of embeddings. Expected ${texts.length}, got ${embedContentResponse.embeddings.length}.`,
);
}
return embedContentResponse.embeddings.map((embedding, index) => {
const values = embedding.values;
if (!values || values.length === 0) {
throw new Error(
`API returned an empty embedding for input text at index ${index}: "${texts[index]}"`,
);
}
return values;
});
}
private cleanJsonResponse(text: string, model: string): string {
const prefix = '```json';
const suffix = '```';

View File

@@ -237,7 +237,6 @@ describe('Gemini Client (client.ts)', () => {
generateContent: mockGenerateContentFn,
generateContentStream: vi.fn(),
countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }),
embedContent: vi.fn(),
batchEmbedContents: vi.fn(),
} as unknown as ContentGenerator;
@@ -312,97 +311,6 @@ describe('Gemini Client (client.ts)', () => {
vi.restoreAllMocks();
});
describe('generateEmbedding', () => {
const texts = ['hello world', 'goodbye world'];
const testEmbeddingModel = 'test-embedding-model';
it('should call embedContent with correct parameters and return embeddings', async () => {
const mockEmbeddings = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
];
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
embeddings: [
{ values: mockEmbeddings[0] },
{ values: mockEmbeddings[1] },
],
});
const result = await client.generateEmbedding(texts);
expect(mockContentGenerator.embedContent).toHaveBeenCalledTimes(1);
expect(mockContentGenerator.embedContent).toHaveBeenCalledWith({
model: testEmbeddingModel,
contents: texts,
});
expect(result).toEqual(mockEmbeddings);
});
it('should return an empty array if an empty array is passed', async () => {
const result = await client.generateEmbedding([]);
expect(result).toEqual([]);
expect(mockContentGenerator.embedContent).not.toHaveBeenCalled();
});
it('should throw an error if API response has no embeddings array', async () => {
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'No embeddings found in API response.',
);
});
it('should throw an error if API response has an empty embeddings array', async () => {
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
embeddings: [],
});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'No embeddings found in API response.',
);
});
it('should throw an error if API returns a mismatched number of embeddings', async () => {
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
embeddings: [{ values: [1, 2, 3] }], // Only one for two texts
});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API returned a mismatched number of embeddings. Expected 2, got 1.',
);
});
it('should throw an error if any embedding has nullish values', async () => {
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad
});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API returned an empty embedding for input text at index 1: "goodbye world"',
);
});
it('should throw an error if any embedding has an empty values array', async () => {
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad
});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API returned an empty embedding for input text at index 0: "hello world"',
);
});
it('should propagate errors from the API call', async () => {
vi.mocked(mockContentGenerator.embedContent).mockRejectedValue(
new Error('API Failure'),
);
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API Failure',
);
});
});
describe('generateJson', () => {
it('should call generateContent with the correct parameters', async () => {
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];

View File

@@ -5,7 +5,6 @@
*/
import type {
EmbedContentParameters,
GenerateContentConfig,
PartListUnion,
Content,
@@ -758,41 +757,6 @@ export class GeminiClient {
}
}
async generateEmbedding(texts: string[]): Promise<number[][]> {
if (!texts || texts.length === 0) {
return [];
}
const embedModelParams: EmbedContentParameters = {
model: this.config.getEmbeddingModel(),
contents: texts,
};
const embedContentResponse =
await this.getContentGeneratorOrFail().embedContent(embedModelParams);
if (
!embedContentResponse.embeddings ||
embedContentResponse.embeddings.length === 0
) {
throw new Error('No embeddings found in API response.');
}
if (embedContentResponse.embeddings.length !== texts.length) {
throw new Error(
`API returned a mismatched number of embeddings. Expected ${texts.length}, got ${embedContentResponse.embeddings.length}.`,
);
}
return embedContentResponse.embeddings.map((embedding, index) => {
const values = embedding.values;
if (!values || values.length === 0) {
throw new Error(
`API returned an empty embedding for input text at index ${index}: "${texts[index]}"`,
);
}
return values;
});
}
async tryCompressChat(
prompt_id: string,
force: boolean = false,