mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-12 12:54:07 -07:00
refactor(core): Move generateEmbedding to BaseLlmClient (#8442)
This commit is contained in:
@@ -40,9 +40,11 @@ vi.mock('../utils/retry.js', () => ({
|
|||||||
}));
|
}));
|
||||||
|
|
||||||
const mockGenerateContent = vi.fn();
|
const mockGenerateContent = vi.fn();
|
||||||
|
const mockEmbedContent = vi.fn();
|
||||||
|
|
||||||
const mockContentGenerator = {
|
const mockContentGenerator = {
|
||||||
generateContent: mockGenerateContent,
|
generateContent: mockGenerateContent,
|
||||||
|
embedContent: mockEmbedContent,
|
||||||
} as unknown as Mocked<ContentGenerator>;
|
} as unknown as Mocked<ContentGenerator>;
|
||||||
|
|
||||||
const mockConfig = {
|
const mockConfig = {
|
||||||
@@ -50,6 +52,7 @@ const mockConfig = {
|
|||||||
getContentGeneratorConfig: vi
|
getContentGeneratorConfig: vi
|
||||||
.fn()
|
.fn()
|
||||||
.mockReturnValue({ authType: AuthType.USE_GEMINI }),
|
.mockReturnValue({ authType: AuthType.USE_GEMINI }),
|
||||||
|
getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'),
|
||||||
} as unknown as Mocked<Config>;
|
} as unknown as Mocked<Config>;
|
||||||
|
|
||||||
// Helper to create a mock GenerateContentResponse
|
// Helper to create a mock GenerateContentResponse
|
||||||
@@ -288,4 +291,93 @@ describe('BaseLlmClient', () => {
|
|||||||
expect(reportError).not.toHaveBeenCalled();
|
expect(reportError).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('generateEmbedding', () => {
|
||||||
|
const texts = ['hello world', 'goodbye world'];
|
||||||
|
const testEmbeddingModel = 'test-embedding-model';
|
||||||
|
|
||||||
|
it('should call embedContent with correct parameters and return embeddings', async () => {
|
||||||
|
const mockEmbeddings = [
|
||||||
|
[0.1, 0.2, 0.3],
|
||||||
|
[0.4, 0.5, 0.6],
|
||||||
|
];
|
||||||
|
mockEmbedContent.mockResolvedValue({
|
||||||
|
embeddings: [
|
||||||
|
{ values: mockEmbeddings[0] },
|
||||||
|
{ values: mockEmbeddings[1] },
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await client.generateEmbedding(texts);
|
||||||
|
|
||||||
|
expect(mockEmbedContent).toHaveBeenCalledTimes(1);
|
||||||
|
expect(mockEmbedContent).toHaveBeenCalledWith({
|
||||||
|
model: testEmbeddingModel,
|
||||||
|
contents: texts,
|
||||||
|
});
|
||||||
|
expect(result).toEqual(mockEmbeddings);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return an empty array if an empty array is passed', async () => {
|
||||||
|
const result = await client.generateEmbedding([]);
|
||||||
|
expect(result).toEqual([]);
|
||||||
|
expect(mockEmbedContent).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw an error if API response has no embeddings array', async () => {
|
||||||
|
mockEmbedContent.mockResolvedValue({});
|
||||||
|
|
||||||
|
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||||
|
'No embeddings found in API response.',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw an error if API response has an empty embeddings array', async () => {
|
||||||
|
mockEmbedContent.mockResolvedValue({
|
||||||
|
embeddings: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||||
|
'No embeddings found in API response.',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw an error if API returns a mismatched number of embeddings', async () => {
|
||||||
|
mockEmbedContent.mockResolvedValue({
|
||||||
|
embeddings: [{ values: [1, 2, 3] }], // Only one for two texts
|
||||||
|
});
|
||||||
|
|
||||||
|
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||||
|
'API returned a mismatched number of embeddings. Expected 2, got 1.',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw an error if any embedding has nullish values', async () => {
|
||||||
|
mockEmbedContent.mockResolvedValue({
|
||||||
|
embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad
|
||||||
|
});
|
||||||
|
|
||||||
|
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||||
|
'API returned an empty embedding for input text at index 1: "goodbye world"',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw an error if any embedding has an empty values array', async () => {
|
||||||
|
mockEmbedContent.mockResolvedValue({
|
||||||
|
embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad
|
||||||
|
});
|
||||||
|
|
||||||
|
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||||
|
'API returned an empty embedding for input text at index 0: "hello world"',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should propagate errors from the API call', async () => {
|
||||||
|
mockEmbedContent.mockRejectedValue(new Error('API Failure'));
|
||||||
|
|
||||||
|
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||||
|
'API Failure',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -4,7 +4,12 @@
|
|||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import type { Content, GenerateContentConfig, Part } from '@google/genai';
|
import type {
|
||||||
|
Content,
|
||||||
|
GenerateContentConfig,
|
||||||
|
Part,
|
||||||
|
EmbedContentParameters,
|
||||||
|
} from '@google/genai';
|
||||||
import type { Config } from '../config/config.js';
|
import type { Config } from '../config/config.js';
|
||||||
import type { ContentGenerator } from './contentGenerator.js';
|
import type { ContentGenerator } from './contentGenerator.js';
|
||||||
import { getResponseText } from '../utils/partUtils.js';
|
import { getResponseText } from '../utils/partUtils.js';
|
||||||
@@ -156,6 +161,41 @@ export class BaseLlmClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async generateEmbedding(texts: string[]): Promise<number[][]> {
|
||||||
|
if (!texts || texts.length === 0) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
const embedModelParams: EmbedContentParameters = {
|
||||||
|
model: this.config.getEmbeddingModel(),
|
||||||
|
contents: texts,
|
||||||
|
};
|
||||||
|
|
||||||
|
const embedContentResponse =
|
||||||
|
await this.contentGenerator.embedContent(embedModelParams);
|
||||||
|
if (
|
||||||
|
!embedContentResponse.embeddings ||
|
||||||
|
embedContentResponse.embeddings.length === 0
|
||||||
|
) {
|
||||||
|
throw new Error('No embeddings found in API response.');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (embedContentResponse.embeddings.length !== texts.length) {
|
||||||
|
throw new Error(
|
||||||
|
`API returned a mismatched number of embeddings. Expected ${texts.length}, got ${embedContentResponse.embeddings.length}.`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return embedContentResponse.embeddings.map((embedding, index) => {
|
||||||
|
const values = embedding.values;
|
||||||
|
if (!values || values.length === 0) {
|
||||||
|
throw new Error(
|
||||||
|
`API returned an empty embedding for input text at index ${index}: "${texts[index]}"`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return values;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
private cleanJsonResponse(text: string, model: string): string {
|
private cleanJsonResponse(text: string, model: string): string {
|
||||||
const prefix = '```json';
|
const prefix = '```json';
|
||||||
const suffix = '```';
|
const suffix = '```';
|
||||||
|
|||||||
@@ -237,7 +237,6 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
generateContent: mockGenerateContentFn,
|
generateContent: mockGenerateContentFn,
|
||||||
generateContentStream: vi.fn(),
|
generateContentStream: vi.fn(),
|
||||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }),
|
countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }),
|
||||||
embedContent: vi.fn(),
|
|
||||||
batchEmbedContents: vi.fn(),
|
batchEmbedContents: vi.fn(),
|
||||||
} as unknown as ContentGenerator;
|
} as unknown as ContentGenerator;
|
||||||
|
|
||||||
@@ -312,97 +311,6 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
vi.restoreAllMocks();
|
vi.restoreAllMocks();
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('generateEmbedding', () => {
|
|
||||||
const texts = ['hello world', 'goodbye world'];
|
|
||||||
const testEmbeddingModel = 'test-embedding-model';
|
|
||||||
|
|
||||||
it('should call embedContent with correct parameters and return embeddings', async () => {
|
|
||||||
const mockEmbeddings = [
|
|
||||||
[0.1, 0.2, 0.3],
|
|
||||||
[0.4, 0.5, 0.6],
|
|
||||||
];
|
|
||||||
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
|
|
||||||
embeddings: [
|
|
||||||
{ values: mockEmbeddings[0] },
|
|
||||||
{ values: mockEmbeddings[1] },
|
|
||||||
],
|
|
||||||
});
|
|
||||||
|
|
||||||
const result = await client.generateEmbedding(texts);
|
|
||||||
|
|
||||||
expect(mockContentGenerator.embedContent).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockContentGenerator.embedContent).toHaveBeenCalledWith({
|
|
||||||
model: testEmbeddingModel,
|
|
||||||
contents: texts,
|
|
||||||
});
|
|
||||||
expect(result).toEqual(mockEmbeddings);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should return an empty array if an empty array is passed', async () => {
|
|
||||||
const result = await client.generateEmbedding([]);
|
|
||||||
expect(result).toEqual([]);
|
|
||||||
expect(mockContentGenerator.embedContent).not.toHaveBeenCalled();
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should throw an error if API response has no embeddings array', async () => {
|
|
||||||
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({});
|
|
||||||
|
|
||||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
|
||||||
'No embeddings found in API response.',
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should throw an error if API response has an empty embeddings array', async () => {
|
|
||||||
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
|
|
||||||
embeddings: [],
|
|
||||||
});
|
|
||||||
|
|
||||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
|
||||||
'No embeddings found in API response.',
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should throw an error if API returns a mismatched number of embeddings', async () => {
|
|
||||||
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
|
|
||||||
embeddings: [{ values: [1, 2, 3] }], // Only one for two texts
|
|
||||||
});
|
|
||||||
|
|
||||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
|
||||||
'API returned a mismatched number of embeddings. Expected 2, got 1.',
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should throw an error if any embedding has nullish values', async () => {
|
|
||||||
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
|
|
||||||
embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad
|
|
||||||
});
|
|
||||||
|
|
||||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
|
||||||
'API returned an empty embedding for input text at index 1: "goodbye world"',
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should throw an error if any embedding has an empty values array', async () => {
|
|
||||||
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
|
|
||||||
embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad
|
|
||||||
});
|
|
||||||
|
|
||||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
|
||||||
'API returned an empty embedding for input text at index 0: "hello world"',
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should propagate errors from the API call', async () => {
|
|
||||||
vi.mocked(mockContentGenerator.embedContent).mockRejectedValue(
|
|
||||||
new Error('API Failure'),
|
|
||||||
);
|
|
||||||
|
|
||||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
|
||||||
'API Failure',
|
|
||||||
);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('generateJson', () => {
|
describe('generateJson', () => {
|
||||||
it('should call generateContent with the correct parameters', async () => {
|
it('should call generateContent with the correct parameters', async () => {
|
||||||
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
|
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import type {
|
import type {
|
||||||
EmbedContentParameters,
|
|
||||||
GenerateContentConfig,
|
GenerateContentConfig,
|
||||||
PartListUnion,
|
PartListUnion,
|
||||||
Content,
|
Content,
|
||||||
@@ -758,41 +757,6 @@ export class GeminiClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async generateEmbedding(texts: string[]): Promise<number[][]> {
|
|
||||||
if (!texts || texts.length === 0) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
const embedModelParams: EmbedContentParameters = {
|
|
||||||
model: this.config.getEmbeddingModel(),
|
|
||||||
contents: texts,
|
|
||||||
};
|
|
||||||
|
|
||||||
const embedContentResponse =
|
|
||||||
await this.getContentGeneratorOrFail().embedContent(embedModelParams);
|
|
||||||
if (
|
|
||||||
!embedContentResponse.embeddings ||
|
|
||||||
embedContentResponse.embeddings.length === 0
|
|
||||||
) {
|
|
||||||
throw new Error('No embeddings found in API response.');
|
|
||||||
}
|
|
||||||
|
|
||||||
if (embedContentResponse.embeddings.length !== texts.length) {
|
|
||||||
throw new Error(
|
|
||||||
`API returned a mismatched number of embeddings. Expected ${texts.length}, got ${embedContentResponse.embeddings.length}.`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return embedContentResponse.embeddings.map((embedding, index) => {
|
|
||||||
const values = embedding.values;
|
|
||||||
if (!values || values.length === 0) {
|
|
||||||
throw new Error(
|
|
||||||
`API returned an empty embedding for input text at index ${index}: "${texts[index]}"`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return values;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
async tryCompressChat(
|
async tryCompressChat(
|
||||||
prompt_id: string,
|
prompt_id: string,
|
||||||
force: boolean = false,
|
force: boolean = false,
|
||||||
|
|||||||
Reference in New Issue
Block a user