mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 05:12:55 -07:00
feat(core): Late resolve GenerateContentConfigs and reduce mutation. (#14920)
This commit is contained in:
@@ -132,11 +132,15 @@ describe('policyHelpers', () => {
|
|||||||
|
|
||||||
it('returns requested model if it is available', () => {
|
it('returns requested model if it is available', () => {
|
||||||
const config = createExtendedMockConfig();
|
const config = createExtendedMockConfig();
|
||||||
|
mockModelConfigService.getResolvedConfig.mockReturnValue({
|
||||||
|
model: 'gemini-pro',
|
||||||
|
generateContentConfig: {},
|
||||||
|
});
|
||||||
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||||
selectedModel: 'gemini-pro',
|
selectedModel: 'gemini-pro',
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = applyModelSelection(config, 'gemini-pro');
|
const result = applyModelSelection(config, { model: 'gemini-pro' });
|
||||||
expect(result.model).toBe('gemini-pro');
|
expect(result.model).toBe('gemini-pro');
|
||||||
expect(result.maxAttempts).toBeUndefined();
|
expect(result.maxAttempts).toBeUndefined();
|
||||||
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
|
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
|
||||||
@@ -144,15 +148,20 @@ describe('policyHelpers', () => {
|
|||||||
|
|
||||||
it('switches to backup model and updates config if requested is unavailable', () => {
|
it('switches to backup model and updates config if requested is unavailable', () => {
|
||||||
const config = createExtendedMockConfig();
|
const config = createExtendedMockConfig();
|
||||||
|
mockModelConfigService.getResolvedConfig
|
||||||
|
.mockReturnValueOnce({
|
||||||
|
model: 'gemini-pro',
|
||||||
|
generateContentConfig: { temperature: 0.9, topP: 1 },
|
||||||
|
})
|
||||||
|
.mockReturnValueOnce({
|
||||||
|
model: 'gemini-flash',
|
||||||
|
generateContentConfig: { temperature: 0.1, topP: 1 },
|
||||||
|
});
|
||||||
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||||
selectedModel: 'gemini-flash',
|
selectedModel: 'gemini-flash',
|
||||||
});
|
});
|
||||||
mockModelConfigService.getResolvedConfig.mockReturnValue({
|
|
||||||
generateContentConfig: { temperature: 0.1 },
|
|
||||||
});
|
|
||||||
|
|
||||||
const currentConfig = { temperature: 0.9, topP: 1 };
|
const result = applyModelSelection(config, { model: 'gemini-pro' });
|
||||||
const result = applyModelSelection(config, 'gemini-pro', currentConfig);
|
|
||||||
|
|
||||||
expect(result.model).toBe('gemini-flash');
|
expect(result.model).toBe('gemini-flash');
|
||||||
expect(result.config).toEqual({
|
expect(result.config).toEqual({
|
||||||
@@ -160,6 +169,9 @@ describe('policyHelpers', () => {
|
|||||||
topP: 1,
|
topP: 1,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
|
||||||
|
model: 'gemini-pro',
|
||||||
|
});
|
||||||
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
|
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
|
||||||
model: 'gemini-flash',
|
model: 'gemini-flash',
|
||||||
});
|
});
|
||||||
@@ -168,12 +180,16 @@ describe('policyHelpers', () => {
|
|||||||
|
|
||||||
it('consumes sticky attempt if indicated', () => {
|
it('consumes sticky attempt if indicated', () => {
|
||||||
const config = createExtendedMockConfig();
|
const config = createExtendedMockConfig();
|
||||||
|
mockModelConfigService.getResolvedConfig.mockReturnValue({
|
||||||
|
model: 'gemini-pro',
|
||||||
|
generateContentConfig: {},
|
||||||
|
});
|
||||||
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||||
selectedModel: 'gemini-pro',
|
selectedModel: 'gemini-pro',
|
||||||
attempts: 1,
|
attempts: 1,
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = applyModelSelection(config, 'gemini-pro');
|
const result = applyModelSelection(config, { model: 'gemini-pro' });
|
||||||
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
|
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
|
||||||
'gemini-pro',
|
'gemini-pro',
|
||||||
);
|
);
|
||||||
@@ -182,6 +198,10 @@ describe('policyHelpers', () => {
|
|||||||
|
|
||||||
it('does not consume sticky attempt if consumeAttempt is false', () => {
|
it('does not consume sticky attempt if consumeAttempt is false', () => {
|
||||||
const config = createExtendedMockConfig();
|
const config = createExtendedMockConfig();
|
||||||
|
mockModelConfigService.getResolvedConfig.mockReturnValue({
|
||||||
|
model: 'gemini-pro',
|
||||||
|
generateContentConfig: {},
|
||||||
|
});
|
||||||
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||||
selectedModel: 'gemini-pro',
|
selectedModel: 'gemini-pro',
|
||||||
attempts: 1,
|
attempts: 1,
|
||||||
@@ -189,9 +209,7 @@ describe('policyHelpers', () => {
|
|||||||
|
|
||||||
const result = applyModelSelection(
|
const result = applyModelSelection(
|
||||||
config,
|
config,
|
||||||
'gemini-pro',
|
{ model: 'gemini-pro' },
|
||||||
undefined,
|
|
||||||
undefined,
|
|
||||||
{
|
{
|
||||||
consumeAttempt: false,
|
consumeAttempt: false,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import {
|
|||||||
resolveModel,
|
resolveModel,
|
||||||
} from '../config/models.js';
|
} from '../config/models.js';
|
||||||
import type { ModelSelectionResult } from './modelAvailabilityService.js';
|
import type { ModelSelectionResult } from './modelAvailabilityService.js';
|
||||||
|
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Resolves the active policy chain for the given config, ensuring the
|
* Resolves the active policy chain for the given config, ensuring the
|
||||||
@@ -155,31 +156,26 @@ export function selectModelForAvailability(
|
|||||||
*/
|
*/
|
||||||
export function applyModelSelection(
|
export function applyModelSelection(
|
||||||
config: Config,
|
config: Config,
|
||||||
requestedModel: string,
|
modelConfigKey: ModelConfigKey,
|
||||||
currentConfig?: GenerateContentConfig,
|
|
||||||
overrideScope?: string,
|
|
||||||
options: { consumeAttempt?: boolean } = {},
|
options: { consumeAttempt?: boolean } = {},
|
||||||
): { model: string; config?: GenerateContentConfig; maxAttempts?: number } {
|
): { model: string; config: GenerateContentConfig; maxAttempts?: number } {
|
||||||
const selection = selectModelForAvailability(config, requestedModel);
|
const resolved = config.modelConfigService.getResolvedConfig(modelConfigKey);
|
||||||
|
const model = resolved.model;
|
||||||
|
const selection = selectModelForAvailability(config, model);
|
||||||
|
|
||||||
if (!selection?.selectedModel) {
|
if (!selection) {
|
||||||
return { model: requestedModel, config: currentConfig };
|
return { model, config: resolved.generateContentConfig };
|
||||||
}
|
}
|
||||||
|
|
||||||
const finalModel = selection.selectedModel;
|
const finalModel = selection.selectedModel ?? model;
|
||||||
let finalConfig = currentConfig;
|
let generateContentConfig = resolved.generateContentConfig;
|
||||||
|
|
||||||
// If model changed, re-resolve config
|
if (finalModel !== model) {
|
||||||
if (finalModel !== requestedModel) {
|
const fallbackResolved = config.modelConfigService.getResolvedConfig({
|
||||||
const { generateContentConfig } =
|
...modelConfigKey,
|
||||||
config.modelConfigService.getResolvedConfig({
|
model: finalModel,
|
||||||
overrideScope,
|
});
|
||||||
model: finalModel,
|
generateContentConfig = fallbackResolved.generateContentConfig;
|
||||||
});
|
|
||||||
|
|
||||||
finalConfig = currentConfig
|
|
||||||
? { ...currentConfig, ...generateContentConfig }
|
|
||||||
: generateContentConfig;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config.setActiveModel(finalModel);
|
config.setActiveModel(finalModel);
|
||||||
@@ -190,7 +186,7 @@ export function applyModelSelection(
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
model: finalModel,
|
model: finalModel,
|
||||||
config: finalConfig,
|
config: generateContentConfig,
|
||||||
maxAttempts: selection.attempts,
|
maxAttempts: selection.attempts,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -776,13 +776,15 @@ describe('BaseLlmClient', () => {
|
|||||||
const getResolvedConfigMock = vi.mocked(
|
const getResolvedConfigMock = vi.mocked(
|
||||||
mockConfig.modelConfigService.getResolvedConfig,
|
mockConfig.modelConfigService.getResolvedConfig,
|
||||||
);
|
);
|
||||||
getResolvedConfigMock
|
getResolvedConfigMock.mockImplementation((key) => {
|
||||||
.mockReturnValueOnce(
|
if (key.model === firstModel) {
|
||||||
makeResolvedModelConfig(firstModel, { temperature: 0.1 }),
|
return makeResolvedModelConfig(firstModel, { temperature: 0.1 });
|
||||||
)
|
}
|
||||||
.mockReturnValueOnce(
|
if (key.model === fallbackModel) {
|
||||||
makeResolvedModelConfig(fallbackModel, { temperature: 0.9 }),
|
return makeResolvedModelConfig(fallbackModel, { temperature: 0.9 });
|
||||||
);
|
}
|
||||||
|
return makeResolvedModelConfig(key.model);
|
||||||
|
});
|
||||||
|
|
||||||
// Availability selects the first model initially
|
// Availability selects the first model initially
|
||||||
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
|
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import type {
|
|||||||
EmbedContentParameters,
|
EmbedContentParameters,
|
||||||
GenerateContentResponse,
|
GenerateContentResponse,
|
||||||
GenerateContentParameters,
|
GenerateContentParameters,
|
||||||
|
GenerateContentConfig,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import type { Config } from '../config/config.js';
|
import type { Config } from '../config/config.js';
|
||||||
import type { ContentGenerator } from './contentGenerator.js';
|
import type { ContentGenerator } from './contentGenerator.js';
|
||||||
@@ -81,6 +82,19 @@ export interface GenerateContentOptions {
|
|||||||
maxAttempts?: number;
|
maxAttempts?: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface _CommonGenerateOptions {
|
||||||
|
modelConfigKey: ModelConfigKey;
|
||||||
|
contents: Content[];
|
||||||
|
systemInstruction?: string | Part | Part[] | Content;
|
||||||
|
abortSignal: AbortSignal;
|
||||||
|
promptId: string;
|
||||||
|
maxAttempts?: number;
|
||||||
|
additionalProperties?: {
|
||||||
|
responseJsonSchema: Record<string, unknown>;
|
||||||
|
responseMimeType: string;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A client dedicated to stateless, utility-focused LLM calls.
|
* A client dedicated to stateless, utility-focused LLM calls.
|
||||||
*/
|
*/
|
||||||
@@ -104,7 +118,7 @@ export class BaseLlmClient {
|
|||||||
maxAttempts,
|
maxAttempts,
|
||||||
} = options;
|
} = options;
|
||||||
|
|
||||||
const { model, generateContentConfig } =
|
const { model } =
|
||||||
this.config.modelConfigService.getResolvedConfig(modelConfigKey);
|
this.config.modelConfigService.getResolvedConfig(modelConfigKey);
|
||||||
|
|
||||||
const shouldRetryOnContent = (response: GenerateContentResponse) => {
|
const shouldRetryOnContent = (response: GenerateContentResponse) => {
|
||||||
@@ -123,18 +137,17 @@ export class BaseLlmClient {
|
|||||||
|
|
||||||
const result = await this._generateWithRetry(
|
const result = await this._generateWithRetry(
|
||||||
{
|
{
|
||||||
model,
|
modelConfigKey,
|
||||||
contents,
|
contents,
|
||||||
config: {
|
abortSignal,
|
||||||
...generateContentConfig,
|
promptId,
|
||||||
...(systemInstruction && { systemInstruction }),
|
maxAttempts,
|
||||||
|
systemInstruction,
|
||||||
|
additionalProperties: {
|
||||||
responseJsonSchema: schema,
|
responseJsonSchema: schema,
|
||||||
responseMimeType: 'application/json',
|
responseMimeType: 'application/json',
|
||||||
abortSignal,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
promptId,
|
|
||||||
maxAttempts,
|
|
||||||
shouldRetryOnContent,
|
shouldRetryOnContent,
|
||||||
'generateJson',
|
'generateJson',
|
||||||
);
|
);
|
||||||
@@ -205,9 +218,6 @@ export class BaseLlmClient {
|
|||||||
maxAttempts,
|
maxAttempts,
|
||||||
} = options;
|
} = options;
|
||||||
|
|
||||||
const { model, generateContentConfig } =
|
|
||||||
this.config.modelConfigService.getResolvedConfig(modelConfigKey);
|
|
||||||
|
|
||||||
const shouldRetryOnContent = (response: GenerateContentResponse) => {
|
const shouldRetryOnContent = (response: GenerateContentResponse) => {
|
||||||
const text = getResponseText(response)?.trim();
|
const text = getResponseText(response)?.trim();
|
||||||
return !text; // Retry on empty response
|
return !text; // Retry on empty response
|
||||||
@@ -215,70 +225,74 @@ export class BaseLlmClient {
|
|||||||
|
|
||||||
return this._generateWithRetry(
|
return this._generateWithRetry(
|
||||||
{
|
{
|
||||||
model,
|
modelConfigKey,
|
||||||
contents,
|
contents,
|
||||||
config: {
|
systemInstruction,
|
||||||
...generateContentConfig,
|
abortSignal,
|
||||||
...(systemInstruction && { systemInstruction }),
|
promptId,
|
||||||
abortSignal,
|
maxAttempts,
|
||||||
},
|
|
||||||
},
|
},
|
||||||
promptId,
|
|
||||||
maxAttempts,
|
|
||||||
shouldRetryOnContent,
|
shouldRetryOnContent,
|
||||||
'generateContent',
|
'generateContent',
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
private async _generateWithRetry(
|
private async _generateWithRetry(
|
||||||
requestParams: GenerateContentParameters,
|
options: _CommonGenerateOptions,
|
||||||
promptId: string,
|
|
||||||
maxAttempts: number | undefined,
|
|
||||||
shouldRetryOnContent: (response: GenerateContentResponse) => boolean,
|
shouldRetryOnContent: (response: GenerateContentResponse) => boolean,
|
||||||
errorContext: 'generateJson' | 'generateContent',
|
errorContext: 'generateJson' | 'generateContent',
|
||||||
): Promise<GenerateContentResponse> {
|
): Promise<GenerateContentResponse> {
|
||||||
const abortSignal = requestParams.config?.abortSignal;
|
const {
|
||||||
|
modelConfigKey,
|
||||||
|
contents,
|
||||||
|
systemInstruction,
|
||||||
|
abortSignal,
|
||||||
|
promptId,
|
||||||
|
maxAttempts,
|
||||||
|
additionalProperties,
|
||||||
|
} = options;
|
||||||
|
|
||||||
|
const {
|
||||||
|
model,
|
||||||
|
config: generateContentConfig,
|
||||||
|
maxAttempts: availabilityMaxAttempts,
|
||||||
|
} = applyModelSelection(this.config, modelConfigKey);
|
||||||
|
|
||||||
|
let currentModel = model;
|
||||||
|
let currentGenerateContentConfig = generateContentConfig;
|
||||||
|
|
||||||
// Define callback to fetch context dynamically since active model may get updated during retry loop
|
// Define callback to fetch context dynamically since active model may get updated during retry loop
|
||||||
const getAvailabilityContext = createAvailabilityContextProvider(
|
const getAvailabilityContext = createAvailabilityContextProvider(
|
||||||
this.config,
|
this.config,
|
||||||
() => requestParams.model,
|
() => currentModel,
|
||||||
);
|
);
|
||||||
|
|
||||||
const {
|
|
||||||
model,
|
|
||||||
config: newConfig,
|
|
||||||
maxAttempts: availabilityMaxAttempts,
|
|
||||||
} = applyModelSelection(
|
|
||||||
this.config,
|
|
||||||
requestParams.model,
|
|
||||||
requestParams.config,
|
|
||||||
);
|
|
||||||
requestParams.model = model;
|
|
||||||
if (newConfig) {
|
|
||||||
requestParams.config = newConfig;
|
|
||||||
}
|
|
||||||
if (abortSignal) {
|
|
||||||
requestParams.config = { ...requestParams.config, abortSignal };
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const apiCall = () => {
|
const apiCall = () => {
|
||||||
// Ensure we use the current active model
|
// Ensure we use the current active model
|
||||||
// in case a fallback occurred in a previous attempt.
|
// in case a fallback occurred in a previous attempt.
|
||||||
const activeModel = this.config.getActiveModel();
|
const activeModel = this.config.getActiveModel();
|
||||||
if (activeModel !== requestParams.model) {
|
if (activeModel !== currentModel) {
|
||||||
requestParams.model = activeModel;
|
currentModel = activeModel;
|
||||||
// Re-resolve config if model changed during retry
|
// Re-resolve config if model changed during retry
|
||||||
const { generateContentConfig } =
|
const { generateContentConfig } =
|
||||||
this.config.modelConfigService.getResolvedConfig({
|
this.config.modelConfigService.getResolvedConfig({
|
||||||
|
...modelConfigKey,
|
||||||
model: activeModel,
|
model: activeModel,
|
||||||
});
|
});
|
||||||
requestParams.config = {
|
currentGenerateContentConfig = generateContentConfig;
|
||||||
...requestParams.config,
|
|
||||||
...generateContentConfig,
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
const finalConfig: GenerateContentConfig = {
|
||||||
|
...currentGenerateContentConfig,
|
||||||
|
...(systemInstruction && { systemInstruction }),
|
||||||
|
...additionalProperties,
|
||||||
|
abortSignal,
|
||||||
|
};
|
||||||
|
const requestParams: GenerateContentParameters = {
|
||||||
|
model: currentModel,
|
||||||
|
config: finalConfig,
|
||||||
|
contents,
|
||||||
|
};
|
||||||
return this.contentGenerator.generateContent(requestParams, promptId);
|
return this.contentGenerator.generateContent(requestParams, promptId);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -289,7 +303,7 @@ export class BaseLlmClient {
|
|||||||
getAvailabilityContext,
|
getAvailabilityContext,
|
||||||
onPersistent429: this.config.isInteractive()
|
onPersistent429: this.config.isInteractive()
|
||||||
? (authType, error) =>
|
? (authType, error) =>
|
||||||
handleFallback(this.config, requestParams.model, authType, error)
|
handleFallback(this.config, currentModel, authType, error)
|
||||||
: undefined,
|
: undefined,
|
||||||
authType:
|
authType:
|
||||||
this.authType ?? this.config.getContentGeneratorConfig()?.authType,
|
this.authType ?? this.config.getContentGeneratorConfig()?.authType,
|
||||||
@@ -307,14 +321,14 @@ export class BaseLlmClient {
|
|||||||
await reportError(
|
await reportError(
|
||||||
error,
|
error,
|
||||||
`API returned invalid content after all retries.`,
|
`API returned invalid content after all retries.`,
|
||||||
requestParams.contents as Content[],
|
contents,
|
||||||
`${errorContext}-invalid-content`,
|
`${errorContext}-invalid-content`,
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
await reportError(
|
await reportError(
|
||||||
error,
|
error,
|
||||||
`Error generating content via API.`,
|
`Error generating content via API.`,
|
||||||
requestParams.contents as Content[],
|
contents,
|
||||||
`${errorContext}-api`,
|
`${errorContext}-api`,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -539,11 +539,10 @@ export class GeminiClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// availability logic
|
// availability logic
|
||||||
|
const modelConfigKey: ModelConfigKey = { model: modelToUse };
|
||||||
const { model: finalModel } = applyModelSelection(
|
const { model: finalModel } = applyModelSelection(
|
||||||
this.config,
|
this.config,
|
||||||
modelToUse,
|
modelConfigKey,
|
||||||
undefined,
|
|
||||||
undefined,
|
|
||||||
{ consumeAttempt: false },
|
{ consumeAttempt: false },
|
||||||
);
|
);
|
||||||
modelToUse = finalModel;
|
modelToUse = finalModel;
|
||||||
@@ -551,7 +550,7 @@ export class GeminiClient {
|
|||||||
this.currentSequenceModel = modelToUse;
|
this.currentSequenceModel = modelToUse;
|
||||||
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
|
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
|
||||||
|
|
||||||
const resultStream = turn.run({ model: modelToUse }, request, linkedSignal);
|
const resultStream = turn.run(modelConfigKey, request, linkedSignal);
|
||||||
for await (const event of resultStream) {
|
for await (const event of resultStream) {
|
||||||
if (this.loopDetector.addAndCheck(event)) {
|
if (this.loopDetector.addAndCheck(event)) {
|
||||||
yield { type: GeminiEventType.LoopDetected };
|
yield { type: GeminiEventType.LoopDetected };
|
||||||
@@ -676,12 +675,7 @@ export class GeminiClient {
|
|||||||
model,
|
model,
|
||||||
config: newConfig,
|
config: newConfig,
|
||||||
maxAttempts: availabilityMaxAttempts,
|
maxAttempts: availabilityMaxAttempts,
|
||||||
} = applyModelSelection(
|
} = applyModelSelection(this.config, modelConfigKey);
|
||||||
this.config,
|
|
||||||
currentAttemptModel,
|
|
||||||
currentAttemptGenerateContentConfig,
|
|
||||||
modelConfigKey.overrideScope,
|
|
||||||
);
|
|
||||||
currentAttemptModel = model;
|
currentAttemptModel = model;
|
||||||
if (newConfig) {
|
if (newConfig) {
|
||||||
currentAttemptGenerateContentConfig = newConfig;
|
currentAttemptGenerateContentConfig = newConfig;
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ describe('GeminiChat', () => {
|
|||||||
return {
|
return {
|
||||||
model,
|
model,
|
||||||
generateContentConfig: {
|
generateContentConfig: {
|
||||||
temperature: 0,
|
temperature: modelConfigKey.isRetry ? 1 : 0,
|
||||||
thinkingConfig,
|
thinkingConfig,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@@ -2332,13 +2332,18 @@ describe('GeminiChat', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Different configs per model
|
// Different configs per model
|
||||||
vi.mocked(mockConfig.modelConfigService.getResolvedConfig)
|
vi.mocked(
|
||||||
.mockReturnValueOnce(
|
mockConfig.modelConfigService.getResolvedConfig,
|
||||||
makeResolvedModelConfig('model-a', { temperature: 0.1 }),
|
).mockImplementation((key) => {
|
||||||
)
|
if (key.model === 'model-a') {
|
||||||
.mockReturnValueOnce(
|
return makeResolvedModelConfig('model-a', { temperature: 0.1 });
|
||||||
makeResolvedModelConfig('model-b', { temperature: 0.9 }),
|
}
|
||||||
);
|
if (key.model === 'model-b') {
|
||||||
|
return makeResolvedModelConfig('model-b', { temperature: 0.9 });
|
||||||
|
}
|
||||||
|
// Default for the initial requested model in this test
|
||||||
|
return makeResolvedModelConfig('model-a', { temperature: 0.1 });
|
||||||
|
});
|
||||||
|
|
||||||
// First attempt uses model-a, then simulate availability switching to model-b
|
// First attempt uses model-a, then simulate availability switching to model-b
|
||||||
mockRetryWithBackoff.mockImplementation(async (apiCall) => {
|
mockRetryWithBackoff.mockImplementation(async (apiCall) => {
|
||||||
|
|||||||
@@ -16,13 +16,11 @@ import type {
|
|||||||
GenerateContentConfig,
|
GenerateContentConfig,
|
||||||
GenerateContentParameters,
|
GenerateContentParameters,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { ThinkingLevel } from '@google/genai';
|
|
||||||
import { toParts } from '../code_assist/converter.js';
|
import { toParts } from '../code_assist/converter.js';
|
||||||
import { createUserContent, FinishReason } from '@google/genai';
|
import { createUserContent, FinishReason } from '@google/genai';
|
||||||
import { retryWithBackoff, isRetryableError } from '../utils/retry.js';
|
import { retryWithBackoff, isRetryableError } from '../utils/retry.js';
|
||||||
import type { Config } from '../config/config.js';
|
import type { Config } from '../config/config.js';
|
||||||
import {
|
import {
|
||||||
DEFAULT_THINKING_MODE,
|
|
||||||
resolveModel,
|
resolveModel,
|
||||||
isGemini2Model,
|
isGemini2Model,
|
||||||
isPreviewModel,
|
isPreviewModel,
|
||||||
@@ -276,9 +274,8 @@ export class GeminiChat {
|
|||||||
this.sendPromise = streamDonePromise;
|
this.sendPromise = streamDonePromise;
|
||||||
|
|
||||||
const userContent = createUserContent(message);
|
const userContent = createUserContent(message);
|
||||||
const { model, generateContentConfig } =
|
const { model } =
|
||||||
this.config.modelConfigService.getResolvedConfig(modelConfigKey);
|
this.config.modelConfigService.getResolvedConfig(modelConfigKey);
|
||||||
generateContentConfig.abortSignal = signal;
|
|
||||||
|
|
||||||
// Record user input - capture complete message with all parts (text, files, images, etc.)
|
// Record user input - capture complete message with all parts (text, files, images, etc.)
|
||||||
// but skip recording function responses (tool call results) as they should be stored in tool call records
|
// but skip recording function responses (tool call results) as they should be stored in tool call records
|
||||||
@@ -316,17 +313,18 @@ export class GeminiChat {
|
|||||||
yield { type: StreamEventType.RETRY };
|
yield { type: StreamEventType.RETRY };
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this is a retry, set temperature to 1 to encourage different output.
|
// If this is a retry, update the key with the new context.
|
||||||
if (attempt > 0) {
|
const currentConfigKey =
|
||||||
generateContentConfig.temperature = 1;
|
attempt > 0
|
||||||
}
|
? { ...modelConfigKey, isRetry: true }
|
||||||
|
: modelConfigKey;
|
||||||
|
|
||||||
isConnectionPhase = true;
|
isConnectionPhase = true;
|
||||||
const stream = await this.makeApiCallAndProcessStream(
|
const stream = await this.makeApiCallAndProcessStream(
|
||||||
model,
|
currentConfigKey,
|
||||||
generateContentConfig,
|
|
||||||
requestContents,
|
requestContents,
|
||||||
prompt_id,
|
prompt_id,
|
||||||
|
signal,
|
||||||
);
|
);
|
||||||
isConnectionPhase = false;
|
isConnectionPhase = false;
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
@@ -399,10 +397,10 @@ export class GeminiChat {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private async makeApiCallAndProcessStream(
|
private async makeApiCallAndProcessStream(
|
||||||
model: string,
|
modelConfigKey: ModelConfigKey,
|
||||||
generateContentConfig: GenerateContentConfig,
|
|
||||||
requestContents: Content[],
|
requestContents: Content[],
|
||||||
prompt_id: string,
|
prompt_id: string,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||||
const contentsForPreviewModel =
|
const contentsForPreviewModel =
|
||||||
this.ensureActiveLoopHasThoughtSignatures(requestContents);
|
this.ensureActiveLoopHasThoughtSignatures(requestContents);
|
||||||
@@ -412,18 +410,11 @@ export class GeminiChat {
|
|||||||
model: availabilityFinalModel,
|
model: availabilityFinalModel,
|
||||||
config: newAvailabilityConfig,
|
config: newAvailabilityConfig,
|
||||||
maxAttempts: availabilityMaxAttempts,
|
maxAttempts: availabilityMaxAttempts,
|
||||||
} = applyModelSelection(this.config, model, generateContentConfig);
|
} = applyModelSelection(this.config, modelConfigKey);
|
||||||
|
|
||||||
const abortSignal = generateContentConfig.abortSignal;
|
|
||||||
let lastModelToUse = availabilityFinalModel;
|
let lastModelToUse = availabilityFinalModel;
|
||||||
let currentGenerateContentConfig: GenerateContentConfig =
|
let currentGenerateContentConfig: GenerateContentConfig =
|
||||||
newAvailabilityConfig ?? generateContentConfig;
|
newAvailabilityConfig;
|
||||||
if (abortSignal) {
|
|
||||||
currentGenerateContentConfig = {
|
|
||||||
...currentGenerateContentConfig,
|
|
||||||
abortSignal,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
let lastConfig: GenerateContentConfig = currentGenerateContentConfig;
|
let lastConfig: GenerateContentConfig = currentGenerateContentConfig;
|
||||||
let lastContentsToUse: Content[] = requestContents;
|
let lastContentsToUse: Content[] = requestContents;
|
||||||
|
|
||||||
@@ -448,47 +439,27 @@ export class GeminiChat {
|
|||||||
this.config.getActiveModel(),
|
this.config.getActiveModel(),
|
||||||
this.config.getPreviewFeatures(),
|
this.config.getPreviewFeatures(),
|
||||||
);
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (modelToUse !== lastModelToUse) {
|
if (modelToUse !== lastModelToUse) {
|
||||||
const { generateContentConfig: newConfig } =
|
const { generateContentConfig: newConfig } =
|
||||||
this.config.modelConfigService.getResolvedConfig({
|
this.config.modelConfigService.getResolvedConfig({
|
||||||
model: modelToUse,
|
...modelConfigKey,
|
||||||
});
|
model: modelToUse,
|
||||||
currentGenerateContentConfig = {
|
});
|
||||||
...currentGenerateContentConfig,
|
currentGenerateContentConfig = newConfig;
|
||||||
...newConfig,
|
|
||||||
};
|
|
||||||
if (abortSignal) {
|
|
||||||
currentGenerateContentConfig.abortSignal = abortSignal;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
lastModelToUse = modelToUse;
|
lastModelToUse = modelToUse;
|
||||||
const config = {
|
const config: GenerateContentConfig = {
|
||||||
...currentGenerateContentConfig,
|
...currentGenerateContentConfig,
|
||||||
// TODO(12622): Ensure we don't overrwrite these when they are
|
// TODO(12622): Ensure we don't overrwrite these when they are
|
||||||
// passed via config.
|
// passed via config.
|
||||||
systemInstruction: this.systemInstruction,
|
systemInstruction: this.systemInstruction,
|
||||||
tools: this.tools,
|
tools: this.tools,
|
||||||
|
abortSignal,
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO(joshualitt): Clean this up with model configs.
|
|
||||||
if (modelToUse.startsWith('gemini-3')) {
|
|
||||||
config.thinkingConfig = {
|
|
||||||
...config.thinkingConfig,
|
|
||||||
thinkingLevel: ThinkingLevel.HIGH,
|
|
||||||
};
|
|
||||||
delete config.thinkingConfig?.thinkingBudget;
|
|
||||||
} else {
|
|
||||||
// The `gemini-3` configs use thinkingLevel, so we have to invert the
|
|
||||||
// change above.
|
|
||||||
config.thinkingConfig = {
|
|
||||||
...config.thinkingConfig,
|
|
||||||
thinkingBudget: DEFAULT_THINKING_MODE,
|
|
||||||
};
|
|
||||||
delete config.thinkingConfig?.thinkingLevel;
|
|
||||||
}
|
|
||||||
let contentsToUse = isPreviewModel(modelToUse)
|
let contentsToUse = isPreviewModel(modelToUse)
|
||||||
? contentsForPreviewModel
|
? contentsForPreviewModel
|
||||||
: requestContents;
|
: requestContents;
|
||||||
@@ -576,10 +547,11 @@ export class GeminiChat {
|
|||||||
onPersistent429: onPersistent429Callback,
|
onPersistent429: onPersistent429Callback,
|
||||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||||
retryFetchErrors: this.config.getRetryFetchErrors(),
|
retryFetchErrors: this.config.getRetryFetchErrors(),
|
||||||
signal: generateContentConfig.abortSignal,
|
signal: abortSignal,
|
||||||
maxAttempts:
|
maxAttempts:
|
||||||
availabilityMaxAttempts ??
|
availabilityMaxAttempts ??
|
||||||
(this.config.isPreviewModelFallbackMode() && isPreviewModel(model)
|
(this.config.isPreviewModelFallbackMode() &&
|
||||||
|
isPreviewModel(lastModelToUse)
|
||||||
? 1
|
? 1
|
||||||
: undefined),
|
: undefined),
|
||||||
getAvailabilityContext,
|
getAvailabilityContext,
|
||||||
|
|||||||
Reference in New Issue
Block a user