feat(core): enhance availability routing with wrapped fallback and single-model policies (#13874)

This commit is contained in:
Adam Weidman
2025-12-01 12:41:06 -08:00
committed by GitHub
parent 806cd112ac
commit b4df7e351b
5 changed files with 65 additions and 24 deletions
@@ -72,8 +72,11 @@ export function getModelPolicyChain(
/** /**
* Provides a default policy scaffold for models not present in the catalog. * Provides a default policy scaffold for models not present in the catalog.
*/ */
export function createDefaultPolicy(model: string): ModelPolicy { export function createDefaultPolicy(
return definePolicy({ model }); model: string,
options?: { isLastResort?: boolean },
): ModelPolicy {
return definePolicy({ model, isLastResort: options?.isLastResort });
} }
export function validateModelPolicyChain(chain: ModelPolicyChain): void { export function validateModelPolicyChain(chain: ModelPolicyChain): void {
@@ -22,6 +22,7 @@ describe('policyHelpers', () => {
isInFallbackMode: () => false, isInFallbackMode: () => false,
} as unknown as Config; } as unknown as Config;
const chain = resolvePolicyChain(config); const chain = resolvePolicyChain(config);
expect(chain).toHaveLength(1);
expect(chain[0]?.model).toBe('custom-model'); expect(chain[0]?.model).toBe('custom-model');
}); });
@@ -46,7 +47,7 @@ describe('policyHelpers', () => {
]; ];
const context = buildFallbackPolicyContext(chain, 'b'); const context = buildFallbackPolicyContext(chain, 'b');
expect(context.failedPolicy?.model).toBe('b'); expect(context.failedPolicy?.model).toBe('b');
expect(context.candidates.map((p) => p.model)).toEqual(['c']); expect(context.candidates.map((p) => p.model)).toEqual(['c', 'a']);
}); });
it('returns full chain when model is not in policy list', () => { it('returns full chain when model is not in policy list', () => {
@@ -34,7 +34,9 @@ export function resolvePolicyChain(config: Config): ModelPolicyChain {
return chain; return chain;
} }
return [createDefaultPolicy(activeModel), ...chain]; // If the user specified a model not in the default chain, we assume they want
// *only* that model. We do not fallback to the default chain.
return [createDefaultPolicy(activeModel, { isLastResort: true })];
} }
/** /**
@@ -52,9 +54,11 @@ export function buildFallbackPolicyContext(
if (index === -1) { if (index === -1) {
return { failedPolicy: undefined, candidates: chain }; return { failedPolicy: undefined, candidates: chain };
} }
// Return [candidates_after, candidates_before] to prioritize downgrades
// (continuing the chain) before wrapping around to upgrades.
return { return {
failedPolicy: chain[index], failedPolicy: chain[index],
candidates: chain.slice(index + 1), candidates: [...chain.slice(index + 1), ...chain.slice(0, index)],
}; };
} }
+46 -6
View File
@@ -539,7 +539,7 @@ describe('handleFallback', () => {
beforeEach(() => { beforeEach(() => {
vi.clearAllMocks(); vi.clearAllMocks();
availability = createAvailabilityMock({ availability = createAvailabilityMock({
selectedModel: 'gemini-1.5-flash', selectedModel: DEFAULT_GEMINI_FLASH_MODEL,
skipped: [], skipped: [],
}); });
policyHandler = vi.fn().mockResolvedValue('retry_once'); policyHandler = vi.fn().mockResolvedValue('retry_once');
@@ -556,9 +556,16 @@ describe('handleFallback', () => {
); );
}); });
it('uses availability selection when enabled', async () => { it('uses availability selection with correct candidates when enabled', async () => {
await handleFallback(policyConfig, MOCK_PRO_MODEL, AUTH_OAUTH); vi.spyOn(policyConfig, 'getPreviewFeatures').mockReturnValue(true);
expect(availability.selectFirstAvailable).toHaveBeenCalled(); vi.spyOn(policyConfig, 'getModel').mockReturnValue(DEFAULT_GEMINI_MODEL);
await handleFallback(policyConfig, DEFAULT_GEMINI_MODEL, AUTH_OAUTH);
expect(availability.selectFirstAvailable).toHaveBeenCalledWith([
DEFAULT_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_MODEL,
]);
}); });
it('falls back to last resort when availability returns null', async () => { it('falls back to last resort when availability returns null', async () => {
@@ -611,6 +618,33 @@ describe('handleFallback', () => {
} }
}); });
it('wraps around to upgrade candidates if the current model was selected mid-chain (e.g. by router)', async () => {
// Last-resort failure (Flash) in [Preview, Pro, Flash] checks Preview then Pro (all upstream).
vi.spyOn(policyConfig, 'getPreviewFeatures').mockReturnValue(true);
availability.selectFirstAvailable = vi.fn().mockReturnValue({
selectedModel: MOCK_PRO_MODEL,
skipped: [],
});
policyHandler.mockResolvedValue('retry_once');
await handleFallback(
policyConfig,
DEFAULT_GEMINI_FLASH_MODEL,
AUTH_OAUTH,
);
expect(availability.selectFirstAvailable).toHaveBeenCalledWith([
PREVIEW_GEMINI_MODEL,
MOCK_PRO_MODEL,
]);
expect(policyHandler).toHaveBeenCalledWith(
DEFAULT_GEMINI_FLASH_MODEL,
MOCK_PRO_MODEL,
undefined,
);
});
it('logs and returns null when handler resolves to null', async () => { it('logs and returns null when handler resolves to null', async () => {
policyHandler.mockResolvedValue(null); policyHandler.mockResolvedValue(null);
const debugLoggerErrorSpy = vi.spyOn(debugLogger, 'error'); const debugLoggerErrorSpy = vi.spyOn(debugLogger, 'error');
@@ -656,7 +690,12 @@ describe('handleFallback', () => {
); );
}); });
it('short-circuits when the failed model is already the last-resort policy', async () => { it('short-circuits when the failed model is the last-resort policy AND candidates are unavailable', async () => {
// Ensure short-circuit when wrapping to an unavailable upstream model.
availability.selectFirstAvailable = vi
.fn()
.mockReturnValue({ selectedModel: null, skipped: [] });
const result = await handleFallback( const result = await handleFallback(
policyConfig, policyConfig,
DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_FLASH_MODEL,
@@ -664,7 +703,8 @@ describe('handleFallback', () => {
); );
expect(result).toBeNull(); expect(result).toBeNull();
expect(policyConfig.getModelAvailabilityService).not.toHaveBeenCalled(); // Service called to check upstream; no UI handler since nothing selected.
expect(policyConfig.getModelAvailabilityService).toHaveBeenCalled();
expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled(); expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled();
}); });
}); });
+6 -13
View File
@@ -135,20 +135,13 @@ async function handlePolicyDrivenFallback(
candidates.map((policy) => policy.model), candidates.map((policy) => policy.model),
); );
let lastResortPolicy = candidates.find((policy) => policy.isLastResort); const lastResortPolicy = candidates.find((policy) => policy.isLastResort);
if (!lastResortPolicy) { const fallbackModel = selection.selectedModel ?? lastResortPolicy?.model;
debugLogger.warn( const selectedPolicy = candidates.find(
'No isLastResort policy found in candidates, using last candidate as fallback.', (policy) => policy.model === fallbackModel,
); );
lastResortPolicy = candidates[candidates.length - 1];
}
const fallbackModel = selection.selectedModel ?? lastResortPolicy.model; if (!fallbackModel || fallbackModel === failedModel || !selectedPolicy) {
const selectedPolicy =
candidates.find((policy) => policy.model === fallbackModel) ??
lastResortPolicy;
if (!fallbackModel || fallbackModel === failedModel) {
return null; return null;
} }