feat(core): Introduce AgentLoopContext. (#21198)

This commit is contained in:
joshualitt
2026-03-09 09:02:20 -07:00
committed by GitHub
parent 7837194ab5
commit 96b939f63a
14 changed files with 196 additions and 66 deletions
@@ -30,7 +30,7 @@ describe('agent-scheduler', () => {
} as unknown as Mocked<ToolRegistry>;
mockConfig = {
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
toolRegistry: mockToolRegistry,
} as unknown as Mocked<Config>;
});
@@ -69,6 +69,6 @@ describe('agent-scheduler', () => {
// Verify that the scheduler's config has the overridden tool registry
const schedulerConfig = vi.mocked(Scheduler).mock.calls[0][0].config;
expect(schedulerConfig.getToolRegistry()).toBe(mockToolRegistry);
expect(schedulerConfig.toolRegistry).toBe(mockToolRegistry);
});
});
@@ -0,0 +1,27 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { GeminiClient } from '../core/client.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import type { ToolRegistry } from '../tools/tool-registry.js';
/**
* AgentLoopContext represents the execution-scoped view of the world for a single
* agent turn or sub-agent loop.
*/
export interface AgentLoopContext {
/** The unique ID for the current user turn or agent thought loop. */
readonly promptId: string;
/** The registry of tools available to the agent in this context. */
readonly toolRegistry: ToolRegistry;
/** The bus for user confirmations and messages in this context. */
readonly messageBus: MessageBus;
/** The client used to communicate with the LLM in this context. */
readonly geminiClient: GeminiClient;
}
+74 -49
View File
@@ -96,6 +96,7 @@ import type {
import { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
import { ModelRouterService } from '../routing/modelRouterService.js';
import { OutputFormat } from '../output/types.js';
//import { type AgentLoopContext } from './agent-loop-context.js';
import {
ModelConfigService,
type ModelConfig,
@@ -154,6 +155,7 @@ import { CheckerRunner } from '../safety/checker-runner.js';
import { ContextBuilder } from '../safety/context-builder.js';
import { CheckerRegistry } from '../safety/registry.js';
import { ConsecaSafetyChecker } from '../safety/conseca/conseca.js';
import type { AgentLoopContext } from './agent-loop-context.js';
export interface AccessibilitySettings {
/** @deprecated Use ui.loadingPhrases instead. */
@@ -598,8 +600,8 @@ export interface ConfigParameters {
};
}
export class Config implements McpContext {
private toolRegistry!: ToolRegistry;
export class Config implements McpContext, AgentLoopContext {
private _toolRegistry!: ToolRegistry;
private mcpClientManager?: McpClientManager;
private allowedMcpServers: string[];
private blockedMcpServers: string[];
@@ -611,7 +613,7 @@ export class Config implements McpContext {
private agentRegistry!: AgentRegistry;
private readonly acknowledgedAgentsService: AcknowledgedAgentsService;
private skillManager!: SkillManager;
private sessionId: string;
private _sessionId: string;
private clientVersion: string;
private fileSystemService: FileSystemService;
private trackerService?: TrackerService;
@@ -645,7 +647,7 @@ export class Config implements McpContext {
private readonly accessibility: AccessibilitySettings;
private readonly telemetrySettings: TelemetrySettings;
private readonly usageStatisticsEnabled: boolean;
private geminiClient!: GeminiClient;
private _geminiClient!: GeminiClient;
private baseLlmClient!: BaseLlmClient;
private localLiteRtLmClient?: LocalLiteRtLmClient;
private modelRouterService: ModelRouterService;
@@ -740,7 +742,7 @@ export class Config implements McpContext {
private readonly fileExclusions: FileExclusions;
private readonly eventEmitter?: EventEmitter;
private readonly useWriteTodos: boolean;
private readonly messageBus: MessageBus;
private readonly _messageBus: MessageBus;
private readonly policyEngine: PolicyEngine;
private policyUpdateConfirmationRequest:
| PolicyUpdateConfirmationRequest
@@ -806,7 +808,7 @@ export class Config implements McpContext {
private approvedPlanPath: string | undefined;
constructor(params: ConfigParameters) {
this.sessionId = params.sessionId;
this._sessionId = params.sessionId;
this.clientVersion = params.clientVersion ?? 'unknown';
this.approvedPlanPath = undefined;
this.embeddingModel =
@@ -961,7 +963,7 @@ export class Config implements McpContext {
(params.shellToolInactivityTimeout ?? 300) * 1000; // 5 minutes
this.extensionManagement = params.extensionManagement ?? true;
this.enableExtensionReloading = params.enableExtensionReloading ?? false;
this.storage = new Storage(this.targetDir, this.sessionId);
this.storage = new Storage(this.targetDir, this._sessionId);
this.storage.setCustomPlansDir(params.planSettings?.directory);
this.fakeResponses = params.fakeResponses;
@@ -997,7 +999,7 @@ export class Config implements McpContext {
ConsecaSafetyChecker.getInstance().setConfig(this);
}
this.messageBus = new MessageBus(this.policyEngine, this.debugMode);
this._messageBus = new MessageBus(this.policyEngine, this.debugMode);
this.acknowledgedAgentsService = new AcknowledgedAgentsService();
this.skillManager = new SkillManager();
this.outputSettings = {
@@ -1057,7 +1059,7 @@ export class Config implements McpContext {
);
}
}
this.geminiClient = new GeminiClient(this);
this._geminiClient = new GeminiClient(this);
this.modelRouterService = new ModelRouterService(this);
// HACK: The settings loading logic doesn't currently merge the default
@@ -1142,11 +1144,11 @@ export class Config implements McpContext {
coreEvents.on(CoreEvent.AgentsRefreshed, this.onAgentsRefreshed);
this.toolRegistry = await this.createToolRegistry();
this._toolRegistry = await this.createToolRegistry();
discoverToolsHandle?.end();
this.mcpClientManager = new McpClientManager(
this.clientVersion,
this.toolRegistry,
this._toolRegistry,
this,
this.eventEmitter,
);
@@ -1181,7 +1183,7 @@ export class Config implements McpContext {
if (this.getSkillManager().getSkills().length > 0) {
this.getToolRegistry().unregisterTool(ActivateSkillTool.Name);
this.getToolRegistry().registerTool(
new ActivateSkillTool(this, this.messageBus),
new ActivateSkillTool(this, this._messageBus),
);
}
}
@@ -1198,7 +1200,7 @@ export class Config implements McpContext {
await this.contextManager.refresh();
}
await this.geminiClient.initialize();
await this._geminiClient.initialize();
this.initialized = true;
}
@@ -1222,7 +1224,7 @@ export class Config implements McpContext {
authMethod !== AuthType.USE_GEMINI
) {
// Restore the conversation history to the new client
this.geminiClient.stripThoughtsFromHistory();
this._geminiClient.stripThoughtsFromHistory();
}
// Reset availability status when switching auth (e.g. from limited key to OAuth)
@@ -1343,12 +1345,28 @@ export class Config implements McpContext {
return this.localLiteRtLmClient;
}
get promptId(): string {
return this._sessionId;
}
get toolRegistry(): ToolRegistry {
return this._toolRegistry;
}
get messageBus(): MessageBus {
return this._messageBus;
}
get geminiClient(): GeminiClient {
return this._geminiClient;
}
getSessionId(): string {
return this.sessionId;
return this.promptId;
}
setSessionId(sessionId: string): void {
this.sessionId = sessionId;
this._sessionId = sessionId;
}
setTerminalBackground(terminalBackground: string | undefined): void {
@@ -1613,6 +1631,7 @@ export class Config implements McpContext {
return this.acknowledgedAgentsService;
}
/** @deprecated Use toolRegistry getter */
getToolRegistry(): ToolRegistry {
return this.toolRegistry;
}
@@ -1889,9 +1908,9 @@ export class Config implements McpContext {
);
await refreshServerHierarchicalMemory(this);
}
if (this.geminiClient?.isInitialized()) {
await this.geminiClient.setTools();
this.geminiClient.updateSystemInstruction();
if (this._geminiClient?.isInitialized()) {
await this._geminiClient.setTools();
this._geminiClient.updateSystemInstruction();
}
}
@@ -2045,8 +2064,8 @@ export class Config implements McpContext {
(currentMode === ApprovalMode.YOLO || mode === ApprovalMode.YOLO);
if (isPlanModeTransition || isYoloModeTransition) {
if (this.geminiClient?.isInitialized()) {
this.geminiClient.setTools().catch((err) => {
if (this._geminiClient?.isInitialized()) {
this._geminiClient.setTools().catch((err) => {
debugLogger.error('Failed to update tools', err);
});
}
@@ -2142,6 +2161,7 @@ export class Config implements McpContext {
return this.telemetrySettings.useCliAuth ?? false;
}
/** @deprecated Use geminiClient getter */
getGeminiClient(): GeminiClient {
return this.geminiClient;
}
@@ -2577,7 +2597,7 @@ export class Config implements McpContext {
if (this.getSkillManager().getSkills().length > 0) {
this.getToolRegistry().unregisterTool(ActivateSkillTool.Name);
this.getToolRegistry().registerTool(
new ActivateSkillTool(this, this.messageBus),
new ActivateSkillTool(this, this._messageBus),
);
} else {
this.getToolRegistry().unregisterTool(ActivateSkillTool.Name);
@@ -2703,6 +2723,7 @@ export class Config implements McpContext {
return this.fileExclusions;
}
/** @deprecated Use messageBus getter */
getMessageBus(): MessageBus {
return this.messageBus;
}
@@ -2760,7 +2781,7 @@ export class Config implements McpContext {
}
async createToolRegistry(): Promise<ToolRegistry> {
const registry = new ToolRegistry(this, this.messageBus);
const registry = new ToolRegistry(this, this._messageBus);
// helper to create & register core tools that are enabled
const maybeRegister = (
@@ -2790,10 +2811,10 @@ export class Config implements McpContext {
};
maybeRegister(LSTool, () =>
registry.registerTool(new LSTool(this, this.messageBus)),
registry.registerTool(new LSTool(this, this._messageBus)),
);
maybeRegister(ReadFileTool, () =>
registry.registerTool(new ReadFileTool(this, this.messageBus)),
registry.registerTool(new ReadFileTool(this, this._messageBus)),
);
if (this.getUseRipgrep()) {
@@ -2806,81 +2827,85 @@ export class Config implements McpContext {
}
if (useRipgrep) {
maybeRegister(RipGrepTool, () =>
registry.registerTool(new RipGrepTool(this, this.messageBus)),
registry.registerTool(new RipGrepTool(this, this._messageBus)),
);
} else {
logRipgrepFallback(this, new RipgrepFallbackEvent(errorString));
maybeRegister(GrepTool, () =>
registry.registerTool(new GrepTool(this, this.messageBus)),
registry.registerTool(new GrepTool(this, this._messageBus)),
);
}
} else {
maybeRegister(GrepTool, () =>
registry.registerTool(new GrepTool(this, this.messageBus)),
registry.registerTool(new GrepTool(this, this._messageBus)),
);
}
maybeRegister(GlobTool, () =>
registry.registerTool(new GlobTool(this, this.messageBus)),
registry.registerTool(new GlobTool(this, this._messageBus)),
);
maybeRegister(ActivateSkillTool, () =>
registry.registerTool(new ActivateSkillTool(this, this.messageBus)),
registry.registerTool(new ActivateSkillTool(this, this._messageBus)),
);
maybeRegister(EditTool, () =>
registry.registerTool(new EditTool(this, this.messageBus)),
registry.registerTool(new EditTool(this, this._messageBus)),
);
maybeRegister(WriteFileTool, () =>
registry.registerTool(new WriteFileTool(this, this.messageBus)),
registry.registerTool(new WriteFileTool(this, this._messageBus)),
);
maybeRegister(WebFetchTool, () =>
registry.registerTool(new WebFetchTool(this, this.messageBus)),
registry.registerTool(new WebFetchTool(this, this._messageBus)),
);
maybeRegister(ShellTool, () =>
registry.registerTool(new ShellTool(this, this.messageBus)),
registry.registerTool(new ShellTool(this, this._messageBus)),
);
maybeRegister(MemoryTool, () =>
registry.registerTool(new MemoryTool(this.messageBus)),
registry.registerTool(new MemoryTool(this._messageBus)),
);
maybeRegister(WebSearchTool, () =>
registry.registerTool(new WebSearchTool(this, this.messageBus)),
registry.registerTool(new WebSearchTool(this, this._messageBus)),
);
maybeRegister(AskUserTool, () =>
registry.registerTool(new AskUserTool(this.messageBus)),
registry.registerTool(new AskUserTool(this._messageBus)),
);
if (this.getUseWriteTodos()) {
maybeRegister(WriteTodosTool, () =>
registry.registerTool(new WriteTodosTool(this.messageBus)),
registry.registerTool(new WriteTodosTool(this._messageBus)),
);
}
if (this.isPlanEnabled()) {
maybeRegister(ExitPlanModeTool, () =>
registry.registerTool(new ExitPlanModeTool(this, this.messageBus)),
registry.registerTool(new ExitPlanModeTool(this, this._messageBus)),
);
maybeRegister(EnterPlanModeTool, () =>
registry.registerTool(new EnterPlanModeTool(this, this.messageBus)),
registry.registerTool(new EnterPlanModeTool(this, this._messageBus)),
);
}
if (this.isTrackerEnabled()) {
maybeRegister(TrackerCreateTaskTool, () =>
registry.registerTool(new TrackerCreateTaskTool(this, this.messageBus)),
registry.registerTool(
new TrackerCreateTaskTool(this, this._messageBus),
),
);
maybeRegister(TrackerUpdateTaskTool, () =>
registry.registerTool(new TrackerUpdateTaskTool(this, this.messageBus)),
registry.registerTool(
new TrackerUpdateTaskTool(this, this._messageBus),
),
);
maybeRegister(TrackerGetTaskTool, () =>
registry.registerTool(new TrackerGetTaskTool(this, this.messageBus)),
registry.registerTool(new TrackerGetTaskTool(this, this._messageBus)),
);
maybeRegister(TrackerListTasksTool, () =>
registry.registerTool(new TrackerListTasksTool(this, this.messageBus)),
registry.registerTool(new TrackerListTasksTool(this, this._messageBus)),
);
maybeRegister(TrackerAddDependencyTool, () =>
registry.registerTool(
new TrackerAddDependencyTool(this, this.messageBus),
new TrackerAddDependencyTool(this, this._messageBus),
),
);
maybeRegister(TrackerVisualizeTool, () =>
registry.registerTool(new TrackerVisualizeTool(this, this.messageBus)),
registry.registerTool(new TrackerVisualizeTool(this, this._messageBus)),
);
}
@@ -3007,8 +3032,8 @@ export class Config implements McpContext {
}
private onAgentsRefreshed = async () => {
if (this.toolRegistry) {
this.registerSubAgentTools(this.toolRegistry);
if (this._toolRegistry) {
this.registerSubAgentTools(this._toolRegistry);
}
// Propagate updates to the active chat session
const client = this.getGeminiClient();
@@ -3029,7 +3054,7 @@ export class Config implements McpContext {
this.logCurrentModeDuration(this.getApprovalMode());
coreEvents.off(CoreEvent.AgentsRefreshed, this.onAgentsRefreshed);
this.agentRegistry?.dispose();
this.geminiClient?.dispose();
this._geminiClient?.dispose();
if (this.mcpClientManager) {
await this.mcpClientManager.stop();
}
@@ -290,6 +290,8 @@ function createMockConfig(overrides: Partial<Config> = {}): Config {
const finalConfig = { ...baseConfig, ...overrides } as Config;
(finalConfig as unknown as { config: Config }).config = finalConfig;
// Patch the policy engine to use the final config if not overridden
if (!overrides.getPolicyEngine) {
finalConfig.getPolicyEngine = () =>
+1 -1
View File
@@ -133,7 +133,7 @@ export class CoreToolScheduler {
this.onAllToolCallsComplete = options.onAllToolCallsComplete;
this.onToolCallsUpdate = options.onToolCallsUpdate;
this.getPreferredEditor = options.getPreferredEditor;
this.toolExecutor = new ToolExecutor(this.config);
this.toolExecutor = new ToolExecutor(this.config, this.config);
this.toolModifier = new ToolModificationHandler();
// Subscribe to message bus for ASK_USER policy decisions
+1
View File
@@ -6,6 +6,7 @@
// Export config
export * from './config/config.js';
export * from './config/agent-loop-context.js';
export * from './config/memory.js';
export * from './config/defaultModelConfigs.js';
export * from './config/models.js';
+55 -1
View File
@@ -49,6 +49,9 @@ describe('policy.ts', () => {
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const toolCall = {
request: { name: 'test-tool', args: {} },
tool: { name: 'test-tool' },
@@ -72,6 +75,9 @@ describe('policy.ts', () => {
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const mcpTool = Object.create(DiscoveredMCPTool.prototype);
mcpTool.serverName = 'my-server';
mcpTool._toolAnnotations = { readOnlyHint: true };
@@ -99,6 +105,9 @@ describe('policy.ts', () => {
isInteractive: vi.fn().mockReturnValue(false),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const toolCall = {
request: { name: 'test-tool', args: {} },
tool: { name: 'test-tool' },
@@ -118,6 +127,9 @@ describe('policy.ts', () => {
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const toolCall = {
request: { name: 'test-tool', args: {} },
tool: { name: 'test-tool' },
@@ -137,6 +149,9 @@ describe('policy.ts', () => {
isInteractive: vi.fn().mockReturnValue(true),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const toolCall = {
request: { name: 'test-tool', args: {} },
tool: { name: 'test-tool' },
@@ -152,6 +167,9 @@ describe('policy.ts', () => {
const mockConfig = {
setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const mockMessageBus = {
publish: vi.fn(),
} as unknown as Mocked<MessageBus>;
@@ -175,6 +193,9 @@ describe('policy.ts', () => {
const mockConfig = {
setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const mockMessageBus = {
publish: vi.fn(),
} as unknown as Mocked<MessageBus>;
@@ -200,6 +221,9 @@ describe('policy.ts', () => {
const mockConfig = {
setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const mockMessageBus = {
publish: vi.fn(),
} as unknown as Mocked<MessageBus>;
@@ -225,6 +249,9 @@ describe('policy.ts', () => {
const mockConfig = {
setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const mockMessageBus = {
publish: vi.fn(),
} as unknown as Mocked<MessageBus>;
@@ -256,6 +283,9 @@ describe('policy.ts', () => {
const mockConfig = {
setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const mockMessageBus = {
publish: vi.fn(),
} as unknown as Mocked<MessageBus>;
@@ -290,6 +320,9 @@ describe('policy.ts', () => {
const mockConfig = {
setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const mockMessageBus = {
publish: vi.fn(),
} as unknown as Mocked<MessageBus>;
@@ -308,6 +341,9 @@ describe('policy.ts', () => {
const mockConfig = {
setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const mockMessageBus = {
publish: vi.fn(),
} as unknown as Mocked<MessageBus>;
@@ -325,6 +361,9 @@ describe('policy.ts', () => {
const mockConfig = {
setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const mockMessageBus = {
publish: vi.fn(),
} as unknown as Mocked<MessageBus>;
@@ -344,6 +383,9 @@ describe('policy.ts', () => {
const mockConfig = {
setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const mockMessageBus = {
publish: vi.fn(),
} as unknown as Mocked<MessageBus>;
@@ -378,6 +420,9 @@ describe('policy.ts', () => {
const mockConfig = {
setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const mockMessageBus = {
publish: vi.fn(),
} as unknown as Mocked<MessageBus>;
@@ -410,6 +455,9 @@ describe('policy.ts', () => {
const mockConfig = {
setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config =
mockConfig as Config;
const mockMessageBus = {
publish: vi.fn(),
} as unknown as Mocked<MessageBus>;
@@ -447,6 +495,8 @@ describe('policy.ts', () => {
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
} as unknown as Config;
(mockConfig as unknown as { config: Config }).config = mockConfig;
const { errorMessage, errorType } = getPolicyDenialError(mockConfig);
expect(errorMessage).toBe('Tool execution denied by policy.');
@@ -457,6 +507,8 @@ describe('policy.ts', () => {
const mockConfig = {
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
} as unknown as Config;
(mockConfig as unknown as { config: Config }).config = mockConfig;
const rule = {
decision: PolicyDecision.DENY,
denyMessage: 'Custom Deny',
@@ -517,7 +569,8 @@ describe('Plan Mode Denial Consistency', () => {
mockConfig = {
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
toolRegistry: mockToolRegistry,
getToolRegistry: () => mockToolRegistry,
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
isInteractive: vi.fn().mockReturnValue(true),
getEnableHooks: vi.fn().mockReturnValue(false),
@@ -525,6 +578,7 @@ describe('Plan Mode Denial Consistency', () => {
setApprovalMode: vi.fn(),
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config = mockConfig as Config;
});
afterEach(() => {
@@ -169,13 +169,15 @@ describe('Scheduler (Orchestrator)', () => {
mockConfig = {
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
toolRegistry: mockToolRegistry,
isInteractive: vi.fn().mockReturnValue(true),
getEnableHooks: vi.fn().mockReturnValue(true),
setApprovalMode: vi.fn(),
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config = mockConfig as Config;
mockMessageBus = {
publish: vi.fn(),
subscribe: vi.fn(),
@@ -1320,6 +1322,8 @@ describe('Scheduler MCP Progress', () => {
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config = mockConfig as Config;
mockMessageBus = {
publish: vi.fn(),
subscribe: vi.fn(),
+9 -6
View File
@@ -5,6 +5,7 @@
*/
import type { Config } from '../config/config.js';
import type { AgentLoopContext } from '../config/agent-loop-context.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { SchedulerStateManager } from './state-manager.js';
import { resolveConfirmation } from './confirmation.js';
@@ -57,7 +58,7 @@ interface SchedulerQueueItem {
export interface SchedulerOptions {
config: Config;
messageBus: MessageBus;
messageBus?: MessageBus;
getPreferredEditor: () => EditorType | undefined;
schedulerId: string;
parentCallId?: string;
@@ -97,6 +98,7 @@ export class Scheduler {
private readonly executor: ToolExecutor;
private readonly modifier: ToolModificationHandler;
private readonly config: Config;
private readonly context: AgentLoopContext;
private readonly messageBus: MessageBus;
private readonly getPreferredEditor: () => EditorType | undefined;
private readonly schedulerId: string;
@@ -109,7 +111,8 @@ export class Scheduler {
constructor(options: SchedulerOptions) {
this.config = options.config;
this.messageBus = options.messageBus;
this.context = options.config;
this.messageBus = options.messageBus ?? this.context.messageBus;
this.getPreferredEditor = options.getPreferredEditor;
this.schedulerId = options.schedulerId;
this.parentCallId = options.parentCallId;
@@ -119,7 +122,7 @@ export class Scheduler {
this.schedulerId,
(call) => logToolCall(this.config, new ToolCallEvent(call)),
);
this.executor = new ToolExecutor(this.config);
this.executor = new ToolExecutor(this.config, this.context);
this.modifier = new ToolModificationHandler();
this.setupMessageBusListener(this.messageBus);
@@ -294,7 +297,7 @@ export class Scheduler {
const currentApprovalMode = this.config.getApprovalMode();
try {
const toolRegistry = this.config.getToolRegistry();
const toolRegistry = this.context.toolRegistry;
const newCalls: ToolCall[] = requests.map((request) => {
const enrichedRequest: ToolCallRequestInfo = {
...request,
@@ -697,7 +700,7 @@ export class Scheduler {
const originalRequestName =
result.request.originalRequestName || result.request.name;
const newTool = this.config.getToolRegistry().getTool(tailRequest.name);
const newTool = this.context.toolRegistry.getTool(tailRequest.name);
const newRequest: ToolCallRequestInfo = {
callId: originalCallId,
@@ -713,7 +716,7 @@ export class Scheduler {
// Enqueue an errored tool call
const errorCall = this._createToolNotFoundErroredToolCall(
newRequest,
this.config.getToolRegistry().getAllToolNames(),
this.context.toolRegistry.getAllToolNames(),
);
this.state.replaceActiveCallWithTailCall(callId, errorCall);
} else {
@@ -211,13 +211,15 @@ describe('Scheduler Parallel Execution', () => {
mockConfig = {
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
toolRegistry: mockToolRegistry,
isInteractive: vi.fn().mockReturnValue(true),
getEnableHooks: vi.fn().mockReturnValue(true),
setApprovalMode: vi.fn(),
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
} as unknown as Mocked<Config>;
(mockConfig as unknown as { config: Config }).config = mockConfig as Config;
mockMessageBus = {
publish: vi.fn(),
subscribe: vi.fn(),
@@ -36,7 +36,7 @@ describe('Scheduler waiting callback', () => {
mockTool = new MockTool({ name: 'test_tool' });
toolRegistry = new ToolRegistry(mockConfig, messageBus);
vi.spyOn(mockConfig, 'getToolRegistry').mockReturnValue(toolRegistry);
vi.spyOn(mockConfig, 'toolRegistry', 'get').mockReturnValue(toolRegistry);
toolRegistry.registerTool(mockTool);
vi.mocked(checkPolicy).mockResolvedValue({
@@ -64,7 +64,7 @@ describe('ToolExecutor', () => {
beforeEach(() => {
// Use the standard fake config factory
config = makeFakeConfig();
executor = new ToolExecutor(config);
executor = new ToolExecutor(config, config);
// Reset mocks
vi.resetAllMocks();
+7 -3
View File
@@ -13,6 +13,7 @@ import {
type ToolCallResponseInfo,
type ToolResult,
type Config,
type AgentLoopContext,
type ToolLiveOutput,
} from '../index.js';
import { SHELL_TOOL_NAME } from '../tools/tool-names.js';
@@ -49,7 +50,10 @@ export interface ToolExecutionContext {
}
export class ToolExecutor {
constructor(private readonly config: Config) {}
constructor(
private readonly config: Config,
private readonly context: AgentLoopContext,
) {}
async execute(context: ToolExecutionContext): Promise<CompletedToolCall> {
const { call, signal, outputUpdateHandler, onUpdateToolCall } = context;
@@ -202,7 +206,7 @@ export class ToolExecutor {
toolName,
callId,
this.config.storage.getProjectTempDir(),
this.config.getSessionId(),
this.context.promptId,
);
outputFile = savedPath;
const truncatedContent = formatTruncatedToolOutput(
@@ -241,7 +245,7 @@ export class ToolExecutor {
toolName,
callId,
this.config.storage.getProjectTempDir(),
this.config.getSessionId(),
this.context.promptId,
);
outputFile = savedPath;
const truncatedText = formatTruncatedToolOutput(