diff --git a/packages/a2a-server/src/utils/testing_utils.ts b/packages/a2a-server/src/utils/testing_utils.ts index 977daedf16..7d77d8dc9a 100644 --- a/packages/a2a-server/src/utils/testing_utils.ts +++ b/packages/a2a-server/src/utils/testing_utils.ts @@ -75,6 +75,14 @@ export function createMockConfig( validatePathAccess: vi.fn().mockReturnValue(undefined), ...overrides, } as unknown as Config; + + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + (mockConfig as unknown as { config: Config; promptId: string }).config = + mockConfig; + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + (mockConfig as unknown as { config: Config; promptId: string }).promptId = + 'test-prompt-id'; + mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus()); mockConfig.getHookSystem = vi .fn() diff --git a/packages/core/src/agents/agent-scheduler.test.ts b/packages/core/src/agents/agent-scheduler.test.ts index 5edcb664b6..dd6749d3a0 100644 --- a/packages/core/src/agents/agent-scheduler.test.ts +++ b/packages/core/src/agents/agent-scheduler.test.ts @@ -30,7 +30,7 @@ describe('agent-scheduler', () => { } as unknown as Mocked; mockConfig = { getMessageBus: vi.fn().mockReturnValue(mockMessageBus), - getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + toolRegistry: mockToolRegistry, } as unknown as Mocked; }); @@ -69,6 +69,6 @@ describe('agent-scheduler', () => { // Verify that the scheduler's config has the overridden tool registry const schedulerConfig = vi.mocked(Scheduler).mock.calls[0][0].config; - expect(schedulerConfig.getToolRegistry()).toBe(mockToolRegistry); + expect(schedulerConfig.toolRegistry).toBe(mockToolRegistry); }); }); diff --git a/packages/core/src/config/agent-loop-context.ts b/packages/core/src/config/agent-loop-context.ts new file mode 100644 index 0000000000..0a7334c334 --- /dev/null +++ b/packages/core/src/config/agent-loop-context.ts @@ -0,0 +1,27 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { GeminiClient } from '../core/client.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import type { ToolRegistry } from '../tools/tool-registry.js'; + +/** + * AgentLoopContext represents the execution-scoped view of the world for a single + * agent turn or sub-agent loop. + */ +export interface AgentLoopContext { + /** The unique ID for the current user turn or agent thought loop. */ + readonly promptId: string; + + /** The registry of tools available to the agent in this context. */ + readonly toolRegistry: ToolRegistry; + + /** The bus for user confirmations and messages in this context. */ + readonly messageBus: MessageBus; + + /** The client used to communicate with the LLM in this context. */ + readonly geminiClient: GeminiClient; +} diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index adb36ca3a3..f615564533 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -96,6 +96,7 @@ import type { import { ModelAvailabilityService } from '../availability/modelAvailabilityService.js'; import { ModelRouterService } from '../routing/modelRouterService.js'; import { OutputFormat } from '../output/types.js'; +//import { type AgentLoopContext } from './agent-loop-context.js'; import { ModelConfigService, type ModelConfig, @@ -154,6 +155,7 @@ import { CheckerRunner } from '../safety/checker-runner.js'; import { ContextBuilder } from '../safety/context-builder.js'; import { CheckerRegistry } from '../safety/registry.js'; import { ConsecaSafetyChecker } from '../safety/conseca/conseca.js'; +import type { AgentLoopContext } from './agent-loop-context.js'; export interface AccessibilitySettings { /** @deprecated Use ui.loadingPhrases instead. */ @@ -598,8 +600,8 @@ export interface ConfigParameters { }; } -export class Config implements McpContext { - private toolRegistry!: ToolRegistry; +export class Config implements McpContext, AgentLoopContext { + private _toolRegistry!: ToolRegistry; private mcpClientManager?: McpClientManager; private allowedMcpServers: string[]; private blockedMcpServers: string[]; @@ -611,7 +613,7 @@ export class Config implements McpContext { private agentRegistry!: AgentRegistry; private readonly acknowledgedAgentsService: AcknowledgedAgentsService; private skillManager!: SkillManager; - private sessionId: string; + private _sessionId: string; private clientVersion: string; private fileSystemService: FileSystemService; private trackerService?: TrackerService; @@ -645,7 +647,7 @@ export class Config implements McpContext { private readonly accessibility: AccessibilitySettings; private readonly telemetrySettings: TelemetrySettings; private readonly usageStatisticsEnabled: boolean; - private geminiClient!: GeminiClient; + private _geminiClient!: GeminiClient; private baseLlmClient!: BaseLlmClient; private localLiteRtLmClient?: LocalLiteRtLmClient; private modelRouterService: ModelRouterService; @@ -740,7 +742,7 @@ export class Config implements McpContext { private readonly fileExclusions: FileExclusions; private readonly eventEmitter?: EventEmitter; private readonly useWriteTodos: boolean; - private readonly messageBus: MessageBus; + private readonly _messageBus: MessageBus; private readonly policyEngine: PolicyEngine; private policyUpdateConfirmationRequest: | PolicyUpdateConfirmationRequest @@ -806,7 +808,7 @@ export class Config implements McpContext { private approvedPlanPath: string | undefined; constructor(params: ConfigParameters) { - this.sessionId = params.sessionId; + this._sessionId = params.sessionId; this.clientVersion = params.clientVersion ?? 'unknown'; this.approvedPlanPath = undefined; this.embeddingModel = @@ -961,7 +963,7 @@ export class Config implements McpContext { (params.shellToolInactivityTimeout ?? 300) * 1000; // 5 minutes this.extensionManagement = params.extensionManagement ?? true; this.enableExtensionReloading = params.enableExtensionReloading ?? false; - this.storage = new Storage(this.targetDir, this.sessionId); + this.storage = new Storage(this.targetDir, this._sessionId); this.storage.setCustomPlansDir(params.planSettings?.directory); this.fakeResponses = params.fakeResponses; @@ -997,7 +999,7 @@ export class Config implements McpContext { ConsecaSafetyChecker.getInstance().setConfig(this); } - this.messageBus = new MessageBus(this.policyEngine, this.debugMode); + this._messageBus = new MessageBus(this.policyEngine, this.debugMode); this.acknowledgedAgentsService = new AcknowledgedAgentsService(); this.skillManager = new SkillManager(); this.outputSettings = { @@ -1057,7 +1059,7 @@ export class Config implements McpContext { ); } } - this.geminiClient = new GeminiClient(this); + this._geminiClient = new GeminiClient(this); this.modelRouterService = new ModelRouterService(this); // HACK: The settings loading logic doesn't currently merge the default @@ -1142,11 +1144,11 @@ export class Config implements McpContext { coreEvents.on(CoreEvent.AgentsRefreshed, this.onAgentsRefreshed); - this.toolRegistry = await this.createToolRegistry(); + this._toolRegistry = await this.createToolRegistry(); discoverToolsHandle?.end(); this.mcpClientManager = new McpClientManager( this.clientVersion, - this.toolRegistry, + this._toolRegistry, this, this.eventEmitter, ); @@ -1181,7 +1183,7 @@ export class Config implements McpContext { if (this.getSkillManager().getSkills().length > 0) { this.getToolRegistry().unregisterTool(ActivateSkillTool.Name); this.getToolRegistry().registerTool( - new ActivateSkillTool(this, this.messageBus), + new ActivateSkillTool(this, this._messageBus), ); } } @@ -1198,7 +1200,7 @@ export class Config implements McpContext { await this.contextManager.refresh(); } - await this.geminiClient.initialize(); + await this._geminiClient.initialize(); this.initialized = true; } @@ -1222,7 +1224,7 @@ export class Config implements McpContext { authMethod !== AuthType.USE_GEMINI ) { // Restore the conversation history to the new client - this.geminiClient.stripThoughtsFromHistory(); + this._geminiClient.stripThoughtsFromHistory(); } // Reset availability status when switching auth (e.g. from limited key to OAuth) @@ -1343,12 +1345,28 @@ export class Config implements McpContext { return this.localLiteRtLmClient; } + get promptId(): string { + return this._sessionId; + } + + get toolRegistry(): ToolRegistry { + return this._toolRegistry; + } + + get messageBus(): MessageBus { + return this._messageBus; + } + + get geminiClient(): GeminiClient { + return this._geminiClient; + } + getSessionId(): string { - return this.sessionId; + return this.promptId; } setSessionId(sessionId: string): void { - this.sessionId = sessionId; + this._sessionId = sessionId; } setTerminalBackground(terminalBackground: string | undefined): void { @@ -1613,6 +1631,7 @@ export class Config implements McpContext { return this.acknowledgedAgentsService; } + /** @deprecated Use toolRegistry getter */ getToolRegistry(): ToolRegistry { return this.toolRegistry; } @@ -1889,9 +1908,9 @@ export class Config implements McpContext { ); await refreshServerHierarchicalMemory(this); } - if (this.geminiClient?.isInitialized()) { - await this.geminiClient.setTools(); - this.geminiClient.updateSystemInstruction(); + if (this._geminiClient?.isInitialized()) { + await this._geminiClient.setTools(); + this._geminiClient.updateSystemInstruction(); } } @@ -2045,8 +2064,8 @@ export class Config implements McpContext { (currentMode === ApprovalMode.YOLO || mode === ApprovalMode.YOLO); if (isPlanModeTransition || isYoloModeTransition) { - if (this.geminiClient?.isInitialized()) { - this.geminiClient.setTools().catch((err) => { + if (this._geminiClient?.isInitialized()) { + this._geminiClient.setTools().catch((err) => { debugLogger.error('Failed to update tools', err); }); } @@ -2142,6 +2161,7 @@ export class Config implements McpContext { return this.telemetrySettings.useCliAuth ?? false; } + /** @deprecated Use geminiClient getter */ getGeminiClient(): GeminiClient { return this.geminiClient; } @@ -2577,7 +2597,7 @@ export class Config implements McpContext { if (this.getSkillManager().getSkills().length > 0) { this.getToolRegistry().unregisterTool(ActivateSkillTool.Name); this.getToolRegistry().registerTool( - new ActivateSkillTool(this, this.messageBus), + new ActivateSkillTool(this, this._messageBus), ); } else { this.getToolRegistry().unregisterTool(ActivateSkillTool.Name); @@ -2703,6 +2723,7 @@ export class Config implements McpContext { return this.fileExclusions; } + /** @deprecated Use messageBus getter */ getMessageBus(): MessageBus { return this.messageBus; } @@ -2760,7 +2781,7 @@ export class Config implements McpContext { } async createToolRegistry(): Promise { - const registry = new ToolRegistry(this, this.messageBus); + const registry = new ToolRegistry(this, this._messageBus); // helper to create & register core tools that are enabled const maybeRegister = ( @@ -2790,10 +2811,10 @@ export class Config implements McpContext { }; maybeRegister(LSTool, () => - registry.registerTool(new LSTool(this, this.messageBus)), + registry.registerTool(new LSTool(this, this._messageBus)), ); maybeRegister(ReadFileTool, () => - registry.registerTool(new ReadFileTool(this, this.messageBus)), + registry.registerTool(new ReadFileTool(this, this._messageBus)), ); if (this.getUseRipgrep()) { @@ -2806,81 +2827,85 @@ export class Config implements McpContext { } if (useRipgrep) { maybeRegister(RipGrepTool, () => - registry.registerTool(new RipGrepTool(this, this.messageBus)), + registry.registerTool(new RipGrepTool(this, this._messageBus)), ); } else { logRipgrepFallback(this, new RipgrepFallbackEvent(errorString)); maybeRegister(GrepTool, () => - registry.registerTool(new GrepTool(this, this.messageBus)), + registry.registerTool(new GrepTool(this, this._messageBus)), ); } } else { maybeRegister(GrepTool, () => - registry.registerTool(new GrepTool(this, this.messageBus)), + registry.registerTool(new GrepTool(this, this._messageBus)), ); } maybeRegister(GlobTool, () => - registry.registerTool(new GlobTool(this, this.messageBus)), + registry.registerTool(new GlobTool(this, this._messageBus)), ); maybeRegister(ActivateSkillTool, () => - registry.registerTool(new ActivateSkillTool(this, this.messageBus)), + registry.registerTool(new ActivateSkillTool(this, this._messageBus)), ); maybeRegister(EditTool, () => - registry.registerTool(new EditTool(this, this.messageBus)), + registry.registerTool(new EditTool(this, this._messageBus)), ); maybeRegister(WriteFileTool, () => - registry.registerTool(new WriteFileTool(this, this.messageBus)), + registry.registerTool(new WriteFileTool(this, this._messageBus)), ); maybeRegister(WebFetchTool, () => - registry.registerTool(new WebFetchTool(this, this.messageBus)), + registry.registerTool(new WebFetchTool(this, this._messageBus)), ); maybeRegister(ShellTool, () => - registry.registerTool(new ShellTool(this, this.messageBus)), + registry.registerTool(new ShellTool(this, this._messageBus)), ); maybeRegister(MemoryTool, () => - registry.registerTool(new MemoryTool(this.messageBus)), + registry.registerTool(new MemoryTool(this._messageBus)), ); maybeRegister(WebSearchTool, () => - registry.registerTool(new WebSearchTool(this, this.messageBus)), + registry.registerTool(new WebSearchTool(this, this._messageBus)), ); maybeRegister(AskUserTool, () => - registry.registerTool(new AskUserTool(this.messageBus)), + registry.registerTool(new AskUserTool(this._messageBus)), ); if (this.getUseWriteTodos()) { maybeRegister(WriteTodosTool, () => - registry.registerTool(new WriteTodosTool(this.messageBus)), + registry.registerTool(new WriteTodosTool(this._messageBus)), ); } if (this.isPlanEnabled()) { maybeRegister(ExitPlanModeTool, () => - registry.registerTool(new ExitPlanModeTool(this, this.messageBus)), + registry.registerTool(new ExitPlanModeTool(this, this._messageBus)), ); maybeRegister(EnterPlanModeTool, () => - registry.registerTool(new EnterPlanModeTool(this, this.messageBus)), + registry.registerTool(new EnterPlanModeTool(this, this._messageBus)), ); } if (this.isTrackerEnabled()) { maybeRegister(TrackerCreateTaskTool, () => - registry.registerTool(new TrackerCreateTaskTool(this, this.messageBus)), + registry.registerTool( + new TrackerCreateTaskTool(this, this._messageBus), + ), ); maybeRegister(TrackerUpdateTaskTool, () => - registry.registerTool(new TrackerUpdateTaskTool(this, this.messageBus)), + registry.registerTool( + new TrackerUpdateTaskTool(this, this._messageBus), + ), ); maybeRegister(TrackerGetTaskTool, () => - registry.registerTool(new TrackerGetTaskTool(this, this.messageBus)), + registry.registerTool(new TrackerGetTaskTool(this, this._messageBus)), ); maybeRegister(TrackerListTasksTool, () => - registry.registerTool(new TrackerListTasksTool(this, this.messageBus)), + registry.registerTool(new TrackerListTasksTool(this, this._messageBus)), ); maybeRegister(TrackerAddDependencyTool, () => registry.registerTool( - new TrackerAddDependencyTool(this, this.messageBus), + new TrackerAddDependencyTool(this, this._messageBus), ), ); maybeRegister(TrackerVisualizeTool, () => - registry.registerTool(new TrackerVisualizeTool(this, this.messageBus)), + registry.registerTool(new TrackerVisualizeTool(this, this._messageBus)), ); } @@ -3007,8 +3032,8 @@ export class Config implements McpContext { } private onAgentsRefreshed = async () => { - if (this.toolRegistry) { - this.registerSubAgentTools(this.toolRegistry); + if (this._toolRegistry) { + this.registerSubAgentTools(this._toolRegistry); } // Propagate updates to the active chat session const client = this.getGeminiClient(); @@ -3029,7 +3054,7 @@ export class Config implements McpContext { this.logCurrentModeDuration(this.getApprovalMode()); coreEvents.off(CoreEvent.AgentsRefreshed, this.onAgentsRefreshed); this.agentRegistry?.dispose(); - this.geminiClient?.dispose(); + this._geminiClient?.dispose(); if (this.mcpClientManager) { await this.mcpClientManager.stop(); } diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index fcddc05a44..a2f98dde98 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -290,6 +290,8 @@ function createMockConfig(overrides: Partial = {}): Config { const finalConfig = { ...baseConfig, ...overrides } as Config; + (finalConfig as unknown as { config: Config }).config = finalConfig; + // Patch the policy engine to use the final config if not overridden if (!overrides.getPolicyEngine) { finalConfig.getPolicyEngine = () => diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 23473e199d..15b7f1932b 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -133,7 +133,7 @@ export class CoreToolScheduler { this.onAllToolCallsComplete = options.onAllToolCallsComplete; this.onToolCallsUpdate = options.onToolCallsUpdate; this.getPreferredEditor = options.getPreferredEditor; - this.toolExecutor = new ToolExecutor(this.config); + this.toolExecutor = new ToolExecutor(this.config, this.config); this.toolModifier = new ToolModificationHandler(); // Subscribe to message bus for ASK_USER policy decisions diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 64b27493a0..5dfd74ad61 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -6,6 +6,7 @@ // Export config export * from './config/config.js'; +export * from './config/agent-loop-context.js'; export * from './config/memory.js'; export * from './config/defaultModelConfigs.js'; export * from './config/models.js'; diff --git a/packages/core/src/scheduler/policy.test.ts b/packages/core/src/scheduler/policy.test.ts index 05f5b08a2f..9320893bd6 100644 --- a/packages/core/src/scheduler/policy.test.ts +++ b/packages/core/src/scheduler/policy.test.ts @@ -49,6 +49,9 @@ describe('policy.ts', () => { getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), } as unknown as Mocked; + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; + const toolCall = { request: { name: 'test-tool', args: {} }, tool: { name: 'test-tool' }, @@ -72,6 +75,9 @@ describe('policy.ts', () => { getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), } as unknown as Mocked; + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; + const mcpTool = Object.create(DiscoveredMCPTool.prototype); mcpTool.serverName = 'my-server'; mcpTool._toolAnnotations = { readOnlyHint: true }; @@ -99,6 +105,9 @@ describe('policy.ts', () => { isInteractive: vi.fn().mockReturnValue(false), } as unknown as Mocked; + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; + const toolCall = { request: { name: 'test-tool', args: {} }, tool: { name: 'test-tool' }, @@ -118,6 +127,9 @@ describe('policy.ts', () => { getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), } as unknown as Mocked; + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; + const toolCall = { request: { name: 'test-tool', args: {} }, tool: { name: 'test-tool' }, @@ -137,6 +149,9 @@ describe('policy.ts', () => { isInteractive: vi.fn().mockReturnValue(true), } as unknown as Mocked; + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; + const toolCall = { request: { name: 'test-tool', args: {} }, tool: { name: 'test-tool' }, @@ -152,6 +167,9 @@ describe('policy.ts', () => { const mockConfig = { setApprovalMode: vi.fn(), } as unknown as Mocked; + + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; const mockMessageBus = { publish: vi.fn(), } as unknown as Mocked; @@ -175,6 +193,9 @@ describe('policy.ts', () => { const mockConfig = { setApprovalMode: vi.fn(), } as unknown as Mocked; + + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; const mockMessageBus = { publish: vi.fn(), } as unknown as Mocked; @@ -200,6 +221,9 @@ describe('policy.ts', () => { const mockConfig = { setApprovalMode: vi.fn(), } as unknown as Mocked; + + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; const mockMessageBus = { publish: vi.fn(), } as unknown as Mocked; @@ -225,6 +249,9 @@ describe('policy.ts', () => { const mockConfig = { setApprovalMode: vi.fn(), } as unknown as Mocked; + + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; const mockMessageBus = { publish: vi.fn(), } as unknown as Mocked; @@ -256,6 +283,9 @@ describe('policy.ts', () => { const mockConfig = { setApprovalMode: vi.fn(), } as unknown as Mocked; + + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; const mockMessageBus = { publish: vi.fn(), } as unknown as Mocked; @@ -290,6 +320,9 @@ describe('policy.ts', () => { const mockConfig = { setApprovalMode: vi.fn(), } as unknown as Mocked; + + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; const mockMessageBus = { publish: vi.fn(), } as unknown as Mocked; @@ -308,6 +341,9 @@ describe('policy.ts', () => { const mockConfig = { setApprovalMode: vi.fn(), } as unknown as Mocked; + + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; const mockMessageBus = { publish: vi.fn(), } as unknown as Mocked; @@ -325,6 +361,9 @@ describe('policy.ts', () => { const mockConfig = { setApprovalMode: vi.fn(), } as unknown as Mocked; + + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; const mockMessageBus = { publish: vi.fn(), } as unknown as Mocked; @@ -344,6 +383,9 @@ describe('policy.ts', () => { const mockConfig = { setApprovalMode: vi.fn(), } as unknown as Mocked; + + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; const mockMessageBus = { publish: vi.fn(), } as unknown as Mocked; @@ -378,6 +420,9 @@ describe('policy.ts', () => { const mockConfig = { setApprovalMode: vi.fn(), } as unknown as Mocked; + + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; const mockMessageBus = { publish: vi.fn(), } as unknown as Mocked; @@ -410,6 +455,9 @@ describe('policy.ts', () => { const mockConfig = { setApprovalMode: vi.fn(), } as unknown as Mocked; + + (mockConfig as unknown as { config: Config }).config = + mockConfig as Config; const mockMessageBus = { publish: vi.fn(), } as unknown as Mocked; @@ -447,6 +495,8 @@ describe('policy.ts', () => { getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT), } as unknown as Config; + (mockConfig as unknown as { config: Config }).config = mockConfig; + const { errorMessage, errorType } = getPolicyDenialError(mockConfig); expect(errorMessage).toBe('Tool execution denied by policy.'); @@ -457,6 +507,8 @@ describe('policy.ts', () => { const mockConfig = { getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT), } as unknown as Config; + + (mockConfig as unknown as { config: Config }).config = mockConfig; const rule = { decision: PolicyDecision.DENY, denyMessage: 'Custom Deny', @@ -517,7 +569,8 @@ describe('Plan Mode Denial Consistency', () => { mockConfig = { getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), - getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + toolRegistry: mockToolRegistry, + getToolRegistry: () => mockToolRegistry, getMessageBus: vi.fn().mockReturnValue(mockMessageBus), isInteractive: vi.fn().mockReturnValue(true), getEnableHooks: vi.fn().mockReturnValue(false), @@ -525,6 +578,7 @@ describe('Plan Mode Denial Consistency', () => { setApprovalMode: vi.fn(), getUsageStatisticsEnabled: vi.fn().mockReturnValue(false), } as unknown as Mocked; + (mockConfig as unknown as { config: Config }).config = mockConfig as Config; }); afterEach(() => { diff --git a/packages/core/src/scheduler/scheduler.test.ts b/packages/core/src/scheduler/scheduler.test.ts index ee5438c319..4d40101140 100644 --- a/packages/core/src/scheduler/scheduler.test.ts +++ b/packages/core/src/scheduler/scheduler.test.ts @@ -169,13 +169,15 @@ describe('Scheduler (Orchestrator)', () => { mockConfig = { getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), - getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + toolRegistry: mockToolRegistry, isInteractive: vi.fn().mockReturnValue(true), getEnableHooks: vi.fn().mockReturnValue(true), setApprovalMode: vi.fn(), getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT), } as unknown as Mocked; + (mockConfig as unknown as { config: Config }).config = mockConfig as Config; + mockMessageBus = { publish: vi.fn(), subscribe: vi.fn(), @@ -1320,6 +1322,8 @@ describe('Scheduler MCP Progress', () => { getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT), } as unknown as Mocked; + (mockConfig as unknown as { config: Config }).config = mockConfig as Config; + mockMessageBus = { publish: vi.fn(), subscribe: vi.fn(), diff --git a/packages/core/src/scheduler/scheduler.ts b/packages/core/src/scheduler/scheduler.ts index 38e001ea90..613e23b2d6 100644 --- a/packages/core/src/scheduler/scheduler.ts +++ b/packages/core/src/scheduler/scheduler.ts @@ -5,6 +5,7 @@ */ import type { Config } from '../config/config.js'; +import type { AgentLoopContext } from '../config/agent-loop-context.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { SchedulerStateManager } from './state-manager.js'; import { resolveConfirmation } from './confirmation.js'; @@ -57,7 +58,7 @@ interface SchedulerQueueItem { export interface SchedulerOptions { config: Config; - messageBus: MessageBus; + messageBus?: MessageBus; getPreferredEditor: () => EditorType | undefined; schedulerId: string; parentCallId?: string; @@ -97,6 +98,7 @@ export class Scheduler { private readonly executor: ToolExecutor; private readonly modifier: ToolModificationHandler; private readonly config: Config; + private readonly context: AgentLoopContext; private readonly messageBus: MessageBus; private readonly getPreferredEditor: () => EditorType | undefined; private readonly schedulerId: string; @@ -109,7 +111,8 @@ export class Scheduler { constructor(options: SchedulerOptions) { this.config = options.config; - this.messageBus = options.messageBus; + this.context = options.config; + this.messageBus = options.messageBus ?? this.context.messageBus; this.getPreferredEditor = options.getPreferredEditor; this.schedulerId = options.schedulerId; this.parentCallId = options.parentCallId; @@ -119,7 +122,7 @@ export class Scheduler { this.schedulerId, (call) => logToolCall(this.config, new ToolCallEvent(call)), ); - this.executor = new ToolExecutor(this.config); + this.executor = new ToolExecutor(this.config, this.context); this.modifier = new ToolModificationHandler(); this.setupMessageBusListener(this.messageBus); @@ -294,7 +297,7 @@ export class Scheduler { const currentApprovalMode = this.config.getApprovalMode(); try { - const toolRegistry = this.config.getToolRegistry(); + const toolRegistry = this.context.toolRegistry; const newCalls: ToolCall[] = requests.map((request) => { const enrichedRequest: ToolCallRequestInfo = { ...request, @@ -697,7 +700,7 @@ export class Scheduler { const originalRequestName = result.request.originalRequestName || result.request.name; - const newTool = this.config.getToolRegistry().getTool(tailRequest.name); + const newTool = this.context.toolRegistry.getTool(tailRequest.name); const newRequest: ToolCallRequestInfo = { callId: originalCallId, @@ -713,7 +716,7 @@ export class Scheduler { // Enqueue an errored tool call const errorCall = this._createToolNotFoundErroredToolCall( newRequest, - this.config.getToolRegistry().getAllToolNames(), + this.context.toolRegistry.getAllToolNames(), ); this.state.replaceActiveCallWithTailCall(callId, errorCall); } else { diff --git a/packages/core/src/scheduler/scheduler_parallel.test.ts b/packages/core/src/scheduler/scheduler_parallel.test.ts index 56e6e26243..5342d3ac20 100644 --- a/packages/core/src/scheduler/scheduler_parallel.test.ts +++ b/packages/core/src/scheduler/scheduler_parallel.test.ts @@ -211,13 +211,15 @@ describe('Scheduler Parallel Execution', () => { mockConfig = { getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), - getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + toolRegistry: mockToolRegistry, isInteractive: vi.fn().mockReturnValue(true), getEnableHooks: vi.fn().mockReturnValue(true), setApprovalMode: vi.fn(), getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT), } as unknown as Mocked; + (mockConfig as unknown as { config: Config }).config = mockConfig as Config; + mockMessageBus = { publish: vi.fn(), subscribe: vi.fn(), diff --git a/packages/core/src/scheduler/scheduler_waiting_callback.test.ts b/packages/core/src/scheduler/scheduler_waiting_callback.test.ts index e878a80669..03b754fc86 100644 --- a/packages/core/src/scheduler/scheduler_waiting_callback.test.ts +++ b/packages/core/src/scheduler/scheduler_waiting_callback.test.ts @@ -36,7 +36,7 @@ describe('Scheduler waiting callback', () => { mockTool = new MockTool({ name: 'test_tool' }); toolRegistry = new ToolRegistry(mockConfig, messageBus); - vi.spyOn(mockConfig, 'getToolRegistry').mockReturnValue(toolRegistry); + vi.spyOn(mockConfig, 'toolRegistry', 'get').mockReturnValue(toolRegistry); toolRegistry.registerTool(mockTool); vi.mocked(checkPolicy).mockResolvedValue({ diff --git a/packages/core/src/scheduler/tool-executor.test.ts b/packages/core/src/scheduler/tool-executor.test.ts index e744738341..a193c8ae69 100644 --- a/packages/core/src/scheduler/tool-executor.test.ts +++ b/packages/core/src/scheduler/tool-executor.test.ts @@ -64,7 +64,7 @@ describe('ToolExecutor', () => { beforeEach(() => { // Use the standard fake config factory config = makeFakeConfig(); - executor = new ToolExecutor(config); + executor = new ToolExecutor(config, config); // Reset mocks vi.resetAllMocks(); diff --git a/packages/core/src/scheduler/tool-executor.ts b/packages/core/src/scheduler/tool-executor.ts index 8269f1fc41..e5491630d2 100644 --- a/packages/core/src/scheduler/tool-executor.ts +++ b/packages/core/src/scheduler/tool-executor.ts @@ -13,6 +13,7 @@ import { type ToolCallResponseInfo, type ToolResult, type Config, + type AgentLoopContext, type ToolLiveOutput, } from '../index.js'; import { SHELL_TOOL_NAME } from '../tools/tool-names.js'; @@ -49,7 +50,10 @@ export interface ToolExecutionContext { } export class ToolExecutor { - constructor(private readonly config: Config) {} + constructor( + private readonly config: Config, + private readonly context: AgentLoopContext, + ) {} async execute(context: ToolExecutionContext): Promise { const { call, signal, outputUpdateHandler, onUpdateToolCall } = context; @@ -202,7 +206,7 @@ export class ToolExecutor { toolName, callId, this.config.storage.getProjectTempDir(), - this.config.getSessionId(), + this.context.promptId, ); outputFile = savedPath; const truncatedContent = formatTruncatedToolOutput( @@ -241,7 +245,7 @@ export class ToolExecutor { toolName, callId, this.config.storage.getProjectTempDir(), - this.config.getSessionId(), + this.context.promptId, ); outputFile = savedPath; const truncatedText = formatTruncatedToolOutput(