diff --git a/packages/core/src/core/baseLlmClient.test.ts b/packages/core/src/core/baseLlmClient.test.ts index 1b1787f5fd..ade925a181 100644 --- a/packages/core/src/core/baseLlmClient.test.ts +++ b/packages/core/src/core/baseLlmClient.test.ts @@ -40,9 +40,11 @@ vi.mock('../utils/retry.js', () => ({ })); const mockGenerateContent = vi.fn(); +const mockEmbedContent = vi.fn(); const mockContentGenerator = { generateContent: mockGenerateContent, + embedContent: mockEmbedContent, } as unknown as Mocked; const mockConfig = { @@ -50,6 +52,7 @@ const mockConfig = { getContentGeneratorConfig: vi .fn() .mockReturnValue({ authType: AuthType.USE_GEMINI }), + getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'), } as unknown as Mocked; // Helper to create a mock GenerateContentResponse @@ -288,4 +291,93 @@ describe('BaseLlmClient', () => { expect(reportError).not.toHaveBeenCalled(); }); }); + + describe('generateEmbedding', () => { + const texts = ['hello world', 'goodbye world']; + const testEmbeddingModel = 'test-embedding-model'; + + it('should call embedContent with correct parameters and return embeddings', async () => { + const mockEmbeddings = [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + ]; + mockEmbedContent.mockResolvedValue({ + embeddings: [ + { values: mockEmbeddings[0] }, + { values: mockEmbeddings[1] }, + ], + }); + + const result = await client.generateEmbedding(texts); + + expect(mockEmbedContent).toHaveBeenCalledTimes(1); + expect(mockEmbedContent).toHaveBeenCalledWith({ + model: testEmbeddingModel, + contents: texts, + }); + expect(result).toEqual(mockEmbeddings); + }); + + it('should return an empty array if an empty array is passed', async () => { + const result = await client.generateEmbedding([]); + expect(result).toEqual([]); + expect(mockEmbedContent).not.toHaveBeenCalled(); + }); + + it('should throw an error if API response has no embeddings array', async () => { + mockEmbedContent.mockResolvedValue({}); + + await expect(client.generateEmbedding(texts)).rejects.toThrow( + 'No embeddings found in API response.', + ); + }); + + it('should throw an error if API response has an empty embeddings array', async () => { + mockEmbedContent.mockResolvedValue({ + embeddings: [], + }); + + await expect(client.generateEmbedding(texts)).rejects.toThrow( + 'No embeddings found in API response.', + ); + }); + + it('should throw an error if API returns a mismatched number of embeddings', async () => { + mockEmbedContent.mockResolvedValue({ + embeddings: [{ values: [1, 2, 3] }], // Only one for two texts + }); + + await expect(client.generateEmbedding(texts)).rejects.toThrow( + 'API returned a mismatched number of embeddings. Expected 2, got 1.', + ); + }); + + it('should throw an error if any embedding has nullish values', async () => { + mockEmbedContent.mockResolvedValue({ + embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad + }); + + await expect(client.generateEmbedding(texts)).rejects.toThrow( + 'API returned an empty embedding for input text at index 1: "goodbye world"', + ); + }); + + it('should throw an error if any embedding has an empty values array', async () => { + mockEmbedContent.mockResolvedValue({ + embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad + }); + + await expect(client.generateEmbedding(texts)).rejects.toThrow( + 'API returned an empty embedding for input text at index 0: "hello world"', + ); + }); + + it('should propagate errors from the API call', async () => { + mockEmbedContent.mockRejectedValue(new Error('API Failure')); + + await expect(client.generateEmbedding(texts)).rejects.toThrow( + 'API Failure', + ); + }); + }); }); diff --git a/packages/core/src/core/baseLlmClient.ts b/packages/core/src/core/baseLlmClient.ts index 25a92dabdd..8ce63540fc 100644 --- a/packages/core/src/core/baseLlmClient.ts +++ b/packages/core/src/core/baseLlmClient.ts @@ -4,7 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { Content, GenerateContentConfig, Part } from '@google/genai'; +import type { + Content, + GenerateContentConfig, + Part, + EmbedContentParameters, +} from '@google/genai'; import type { Config } from '../config/config.js'; import type { ContentGenerator } from './contentGenerator.js'; import { getResponseText } from '../utils/partUtils.js'; @@ -156,6 +161,41 @@ export class BaseLlmClient { } } + async generateEmbedding(texts: string[]): Promise { + if (!texts || texts.length === 0) { + return []; + } + const embedModelParams: EmbedContentParameters = { + model: this.config.getEmbeddingModel(), + contents: texts, + }; + + const embedContentResponse = + await this.contentGenerator.embedContent(embedModelParams); + if ( + !embedContentResponse.embeddings || + embedContentResponse.embeddings.length === 0 + ) { + throw new Error('No embeddings found in API response.'); + } + + if (embedContentResponse.embeddings.length !== texts.length) { + throw new Error( + `API returned a mismatched number of embeddings. Expected ${texts.length}, got ${embedContentResponse.embeddings.length}.`, + ); + } + + return embedContentResponse.embeddings.map((embedding, index) => { + const values = embedding.values; + if (!values || values.length === 0) { + throw new Error( + `API returned an empty embedding for input text at index ${index}: "${texts[index]}"`, + ); + } + return values; + }); + } + private cleanJsonResponse(text: string, model: string): string { const prefix = '```json'; const suffix = '```'; diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 2b86b01718..37f002a527 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -237,7 +237,6 @@ describe('Gemini Client (client.ts)', () => { generateContent: mockGenerateContentFn, generateContentStream: vi.fn(), countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }), - embedContent: vi.fn(), batchEmbedContents: vi.fn(), } as unknown as ContentGenerator; @@ -312,97 +311,6 @@ describe('Gemini Client (client.ts)', () => { vi.restoreAllMocks(); }); - describe('generateEmbedding', () => { - const texts = ['hello world', 'goodbye world']; - const testEmbeddingModel = 'test-embedding-model'; - - it('should call embedContent with correct parameters and return embeddings', async () => { - const mockEmbeddings = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - ]; - vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ - embeddings: [ - { values: mockEmbeddings[0] }, - { values: mockEmbeddings[1] }, - ], - }); - - const result = await client.generateEmbedding(texts); - - expect(mockContentGenerator.embedContent).toHaveBeenCalledTimes(1); - expect(mockContentGenerator.embedContent).toHaveBeenCalledWith({ - model: testEmbeddingModel, - contents: texts, - }); - expect(result).toEqual(mockEmbeddings); - }); - - it('should return an empty array if an empty array is passed', async () => { - const result = await client.generateEmbedding([]); - expect(result).toEqual([]); - expect(mockContentGenerator.embedContent).not.toHaveBeenCalled(); - }); - - it('should throw an error if API response has no embeddings array', async () => { - vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({}); - - await expect(client.generateEmbedding(texts)).rejects.toThrow( - 'No embeddings found in API response.', - ); - }); - - it('should throw an error if API response has an empty embeddings array', async () => { - vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ - embeddings: [], - }); - - await expect(client.generateEmbedding(texts)).rejects.toThrow( - 'No embeddings found in API response.', - ); - }); - - it('should throw an error if API returns a mismatched number of embeddings', async () => { - vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ - embeddings: [{ values: [1, 2, 3] }], // Only one for two texts - }); - - await expect(client.generateEmbedding(texts)).rejects.toThrow( - 'API returned a mismatched number of embeddings. Expected 2, got 1.', - ); - }); - - it('should throw an error if any embedding has nullish values', async () => { - vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ - embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad - }); - - await expect(client.generateEmbedding(texts)).rejects.toThrow( - 'API returned an empty embedding for input text at index 1: "goodbye world"', - ); - }); - - it('should throw an error if any embedding has an empty values array', async () => { - vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ - embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad - }); - - await expect(client.generateEmbedding(texts)).rejects.toThrow( - 'API returned an empty embedding for input text at index 0: "hello world"', - ); - }); - - it('should propagate errors from the API call', async () => { - vi.mocked(mockContentGenerator.embedContent).mockRejectedValue( - new Error('API Failure'), - ); - - await expect(client.generateEmbedding(texts)).rejects.toThrow( - 'API Failure', - ); - }); - }); - describe('generateJson', () => { it('should call generateContent with the correct parameters', async () => { const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 25d7aa52fe..c070488fb7 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -5,7 +5,6 @@ */ import type { - EmbedContentParameters, GenerateContentConfig, PartListUnion, Content, @@ -758,41 +757,6 @@ export class GeminiClient { } } - async generateEmbedding(texts: string[]): Promise { - if (!texts || texts.length === 0) { - return []; - } - const embedModelParams: EmbedContentParameters = { - model: this.config.getEmbeddingModel(), - contents: texts, - }; - - const embedContentResponse = - await this.getContentGeneratorOrFail().embedContent(embedModelParams); - if ( - !embedContentResponse.embeddings || - embedContentResponse.embeddings.length === 0 - ) { - throw new Error('No embeddings found in API response.'); - } - - if (embedContentResponse.embeddings.length !== texts.length) { - throw new Error( - `API returned a mismatched number of embeddings. Expected ${texts.length}, got ${embedContentResponse.embeddings.length}.`, - ); - } - - return embedContentResponse.embeddings.map((embedding, index) => { - const values = embedding.values; - if (!values || values.length === 0) { - throw new Error( - `API returned an empty embedding for input text at index ${index}: "${texts[index]}"`, - ); - } - return values; - }); - } - async tryCompressChat( prompt_id: string, force: boolean = false,