Merge branch 'main' into feat/browser-allowed-domain

This commit is contained in:
cynthialong0-0
2026-03-10 06:51:21 -07:00
committed by GitHub
23 changed files with 542 additions and 90 deletions
@@ -19,23 +19,24 @@ vi.mock('../scheduler/scheduler.js', () => ({
}));
describe('agent-scheduler', () => {
let mockConfig: Mocked<Config>;
let mockToolRegistry: Mocked<ToolRegistry>;
let mockMessageBus: Mocked<MessageBus>;
beforeEach(() => {
vi.mocked(Scheduler).mockClear();
mockMessageBus = {} as Mocked<MessageBus>;
mockToolRegistry = {
getTool: vi.fn(),
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
} as unknown as Mocked<ToolRegistry>;
mockConfig = {
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
toolRegistry: mockToolRegistry,
} as unknown as Mocked<Config>;
});
it('should create a scheduler with agent-specific config', async () => {
const mockConfig = {
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
toolRegistry: mockToolRegistry,
} as unknown as Mocked<Config>;
const requests: ToolCallRequestInfo[] = [
{
callId: 'call-1',
@@ -68,8 +69,46 @@ describe('agent-scheduler', () => {
}),
);
// Verify that the scheduler's config has the overridden tool registry
const schedulerConfig = vi.mocked(Scheduler).mock.calls[0][0].config;
expect(schedulerConfig.toolRegistry).toBe(mockToolRegistry);
});
it('should override toolRegistry getter from prototype chain', async () => {
const mainRegistry = { _id: 'main' } as unknown as Mocked<ToolRegistry>;
const agentRegistry = {
_id: 'agent',
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
} as unknown as Mocked<ToolRegistry>;
const config = {
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
} as unknown as Mocked<Config>;
Object.defineProperty(config, 'toolRegistry', {
get: () => mainRegistry,
configurable: true,
});
await scheduleAgentTools(
config as unknown as Config,
[
{
callId: 'c1',
name: 'new_page',
args: {},
isClientInitiated: false,
prompt_id: 'p1',
},
],
{
schedulerId: 'browser-1',
toolRegistry: agentRegistry as unknown as ToolRegistry,
signal: new AbortController().signal,
},
);
const schedulerConfig = vi.mocked(Scheduler).mock.calls[0][0].config;
expect(schedulerConfig.toolRegistry).toBe(agentRegistry);
expect(schedulerConfig.toolRegistry).not.toBe(mainRegistry);
expect(schedulerConfig.getToolRegistry()).toBe(agentRegistry);
});
});
@@ -58,6 +58,11 @@ export async function scheduleAgentTools(
const agentConfig: Config = Object.create(config);
agentConfig.getToolRegistry = () => toolRegistry;
agentConfig.getMessageBus = () => toolRegistry.getMessageBus();
// Override toolRegistry property so AgentLoopContext reads the agent-specific registry.
Object.defineProperty(agentConfig, 'toolRegistry', {
get: () => toolRegistry,
configurable: true,
});
const scheduler = new Scheduler({
config: agentConfig,
@@ -210,6 +210,7 @@ describe('browserAgentFactory', () => {
expect(toolNames).toContain('analyze_screenshot');
});
<<<<<<< feat/browser-allowed-domain
it('should include domain restrictions in system prompt when configured', async () => {
const configWithDomains = makeFakeConfig({
agents: {
@@ -227,6 +228,45 @@ describe('browserAgentFactory', () => {
const systemPrompt = definition.promptConfig?.systemPrompt ?? '';
expect(systemPrompt).toContain('SECURITY DOMAIN RESTRICTION - CRITICAL:');
expect(systemPrompt).toContain('- restricted.com');
=======
it('should include all MCP navigation tools (new_page, navigate_page) in definition', async () => {
mockBrowserManager.getDiscoveredTools.mockResolvedValue([
{ name: 'take_snapshot', description: 'Take snapshot' },
{ name: 'click', description: 'Click element' },
{ name: 'fill', description: 'Fill form field' },
{ name: 'navigate_page', description: 'Navigate to URL' },
{ name: 'new_page', description: 'Open a new page/tab' },
{ name: 'close_page', description: 'Close page' },
{ name: 'select_page', description: 'Select page' },
{ name: 'press_key', description: 'Press key' },
{ name: 'hover', description: 'Hover element' },
]);
const { definition } = await createBrowserAgentDefinition(
mockConfig,
mockMessageBus,
);
const toolNames =
definition.toolConfig?.tools
?.filter(
(t): t is { name: string } => typeof t === 'object' && 'name' in t,
)
.map((t) => t.name) ?? [];
// All MCP tools must be present
expect(toolNames).toContain('new_page');
expect(toolNames).toContain('navigate_page');
expect(toolNames).toContain('close_page');
expect(toolNames).toContain('select_page');
expect(toolNames).toContain('click');
expect(toolNames).toContain('take_snapshot');
expect(toolNames).toContain('press_key');
// Custom composite tool must also be present
expect(toolNames).toContain('type_text');
// Total: 9 MCP + 1 type_text (no analyze_screenshot without visualModel)
expect(definition.toolConfig?.tools).toHaveLength(10);
>>>>>>> main
});
});
@@ -20,14 +20,21 @@ import {
} from '../config/models.js';
import { AuthType } from '../core/contentGenerator.js';
const createMockConfig = (overrides: Partial<Config> = {}): Config =>
({
const createMockConfig = (overrides: Partial<Config> = {}): Config => {
const config = {
getUserTier: () => undefined,
getModel: () => 'gemini-2.5-pro',
getGemini31LaunchedSync: () => false,
getUseCustomToolModelSync: () => {
const useGemini31 = config.getGemini31LaunchedSync();
const authType = config.getContentGeneratorConfig().authType;
return useGemini31 && authType === AuthType.USE_GEMINI;
},
getContentGeneratorConfig: () => ({ authType: undefined }),
...overrides,
}) as unknown as Config;
} as unknown as Config;
return config;
};
describe('policyHelpers', () => {
describe('resolvePolicyChain', () => {
@@ -6,7 +6,6 @@
import type { GenerateContentConfig } from '@google/genai';
import type { Config } from '../config/config.js';
import { AuthType } from '../core/contentGenerator.js';
import type {
FailureKind,
FallbackAction,
@@ -46,9 +45,7 @@ export function resolvePolicyChain(
let chain;
const useGemini31 = config.getGemini31LaunchedSync?.() ?? false;
const useCustomToolModel =
useGemini31 &&
config.getContentGeneratorConfig?.()?.authType === AuthType.USE_GEMINI;
const useCustomToolModel = config.getUseCustomToolModelSync?.() ?? false;
const hasAccessToPreview = config.getHasAccessToPreviewModel?.() ?? true;
const resolvedModel = resolveModel(
+20
View File
@@ -2531,6 +2531,26 @@ export class Config implements McpContext, AgentLoopContext {
return this.getGemini31LaunchedSync();
}
/**
* Returns whether the custom tool model should be used.
*/
async getUseCustomToolModel(): Promise<boolean> {
const useGemini3_1 = await this.getGemini31Launched();
const authType = this.contentGeneratorConfig?.authType;
return useGemini3_1 && authType === AuthType.USE_GEMINI;
}
/**
* Returns whether the custom tool model should be used.
*
* Note: This method should only be called after startup, once experiments have been loaded.
*/
getUseCustomToolModelSync(): boolean {
const useGemini3_1 = this.getGemini31LaunchedSync();
const authType = this.contentGeneratorConfig?.authType;
return useGemini3_1 && authType === AuthType.USE_GEMINI;
}
/**
* Returns whether Gemini 3.1 has been launched.
*
+2 -1
View File
@@ -168,7 +168,8 @@ export function isPreviewModel(model: string): boolean {
model === PREVIEW_GEMINI_3_1_MODEL ||
model === PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL ||
model === PREVIEW_GEMINI_FLASH_MODEL ||
model === PREVIEW_GEMINI_MODEL_AUTO
model === PREVIEW_GEMINI_MODEL_AUTO ||
model === GEMINI_MODEL_ALIAS_AUTO
);
}
@@ -15,7 +15,9 @@ import {
PREVIEW_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL_AUTO,
PREVIEW_GEMINI_MODEL_AUTO,
GEMINI_MODEL_ALIAS_AUTO,
} from '../../config/models.js';
import { AuthType } from '../../core/contentGenerator.js';
import { ApprovalMode } from '../../policy/types.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
@@ -40,6 +42,15 @@ describe('ApprovalModeStrategy', () => {
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
getApprovedPlanPath: vi.fn().mockReturnValue(undefined),
getPlanModeRoutingEnabled: vi.fn().mockResolvedValue(true),
getGemini31Launched: vi.fn().mockResolvedValue(false),
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
const launched = await mockConfig.getGemini31Launched();
const authType = mockConfig.getContentGeneratorConfig?.()?.authType;
return launched && authType === AuthType.USE_GEMINI;
}),
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: AuthType.LOGIN_WITH_GOOGLE,
}),
} as unknown as Config;
mockBaseLlmClient = {} as BaseLlmClient;
@@ -184,4 +195,50 @@ describe('ApprovalModeStrategy', () => {
expect(decision?.model).toBe(PREVIEW_GEMINI_MODEL);
});
it('should route to Preview models when using "auto" alias', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue(GEMINI_MODEL_ALIAS_AUTO);
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision?.model).toBe(PREVIEW_GEMINI_MODEL);
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT);
vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue(
'/path/to/plan.md',
);
const implementationDecision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(implementationDecision?.model).toBe(PREVIEW_GEMINI_FLASH_MODEL);
});
it('should route to Preview Flash model when an approved plan exists and Gemini 3.1 is launched', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue(GEMINI_MODEL_ALIAS_AUTO);
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);
// Exit plan mode with approved plan
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT);
vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue(
'/path/to/plan.md',
);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
// Should resolve to Preview Flash (3.0) because resolveClassifierModel uses preview variants for Gemini 3
expect(decision?.model).toBe(PREVIEW_GEMINI_FLASH_MODEL);
});
});
@@ -6,12 +6,10 @@
import type { Config } from '../../config/config.js';
import {
DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_FLASH_MODEL,
isAutoModel,
isPreviewModel,
resolveClassifierModel,
GEMINI_MODEL_ALIAS_FLASH,
GEMINI_MODEL_ALIAS_PRO,
} from '../../config/models.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import { ApprovalMode } from '../../policy/types.js';
@@ -50,11 +48,19 @@ export class ApprovalModeStrategy implements RoutingStrategy {
const approvalMode = config.getApprovalMode();
const approvedPlanPath = config.getApprovedPlanPath();
const isPreview = isPreviewModel(model);
const [useGemini3_1, useCustomToolModel] = await Promise.all([
config.getGemini31Launched(),
config.getUseCustomToolModel(),
]);
// 1. Planning Phase: If ApprovalMode === PLAN, explicitly route to the Pro model.
if (approvalMode === ApprovalMode.PLAN) {
const proModel = isPreview ? PREVIEW_GEMINI_MODEL : DEFAULT_GEMINI_MODEL;
const proModel = resolveClassifierModel(
model,
GEMINI_MODEL_ALIAS_PRO,
useGemini3_1,
useCustomToolModel,
);
return {
model: proModel,
metadata: {
@@ -65,9 +71,12 @@ export class ApprovalModeStrategy implements RoutingStrategy {
};
} else if (approvedPlanPath) {
// 2. Implementation Phase: If ApprovalMode !== PLAN AND an approved plan path is set, prefer the Flash model.
const flashModel = isPreview
? PREVIEW_GEMINI_FLASH_MODEL
: DEFAULT_GEMINI_FLASH_MODEL;
const flashModel = resolveClassifierModel(
model,
GEMINI_MODEL_ALIAS_FLASH,
useGemini3_1,
useCustomToolModel,
);
return {
model: flashModel,
metadata: {
@@ -59,6 +59,11 @@ describe('ClassifierStrategy', () => {
getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false),
getGemini31Launched: vi.fn().mockResolvedValue(false),
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
const launched = await mockConfig.getGemini31Launched();
const authType = mockConfig.getContentGeneratorConfig().authType;
return launched && authType === AuthType.USE_GEMINI;
}),
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: AuthType.LOGIN_WITH_GOOGLE,
}),
@@ -22,7 +22,6 @@ import {
import { debugLogger } from '../../utils/debugLogger.js';
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
import { LlmRole } from '../../telemetry/types.js';
import { AuthType } from '../../core/contentGenerator.js';
// The number of recent history turns to provide to the router for context.
const HISTORY_TURNS_FOR_CONTEXT = 4;
@@ -172,10 +171,10 @@ export class ClassifierStrategy implements RoutingStrategy {
const reasoning = routerResponse.reasoning;
const latencyMs = Date.now() - startTime;
const useGemini3_1 = (await config.getGemini31Launched?.()) ?? false;
const useCustomToolModel =
useGemini3_1 &&
config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI;
const [useGemini3_1, useCustomToolModel] = await Promise.all([
config.getGemini31Launched(),
config.getUseCustomToolModel(),
]);
const selectedModel = resolveClassifierModel(
model,
routerResponse.model_choice,
@@ -58,6 +58,11 @@ describe('NumericalClassifierStrategy', () => {
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
getGemini31Launched: vi.fn().mockResolvedValue(false),
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
const launched = await mockConfig.getGemini31Launched();
const authType = mockConfig.getContentGeneratorConfig().authType;
return launched && authType === AuthType.USE_GEMINI;
}),
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: AuthType.LOGIN_WITH_GOOGLE,
}),
@@ -18,7 +18,6 @@ import type { Config } from '../../config/config.js';
import { debugLogger } from '../../utils/debugLogger.js';
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
import { LlmRole } from '../../telemetry/types.js';
import { AuthType } from '../../core/contentGenerator.js';
// The number of recent history turns to provide to the router for context.
const HISTORY_TURNS_FOR_CONTEXT = 8;
@@ -185,10 +184,10 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
config,
config.getSessionId() || 'unknown-session',
);
const useGemini3_1 = (await config.getGemini31Launched?.()) ?? false;
const useCustomToolModel =
useGemini3_1 &&
config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI;
const [useGemini3_1, useCustomToolModel] = await Promise.all([
config.getGemini31Launched(),
config.getUseCustomToolModel(),
]);
const selectedModel = resolveClassifierModel(
model,
modelAlias,