mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-25 04:24:51 -07:00
feat(core): Introduce AgentLoopContext. (#21198)
This commit is contained in:
@@ -30,7 +30,7 @@ describe('agent-scheduler', () => {
|
||||
} as unknown as Mocked<ToolRegistry>;
|
||||
mockConfig = {
|
||||
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
toolRegistry: mockToolRegistry,
|
||||
} as unknown as Mocked<Config>;
|
||||
});
|
||||
|
||||
@@ -69,6 +69,6 @@ describe('agent-scheduler', () => {
|
||||
|
||||
// Verify that the scheduler's config has the overridden tool registry
|
||||
const schedulerConfig = vi.mocked(Scheduler).mock.calls[0][0].config;
|
||||
expect(schedulerConfig.getToolRegistry()).toBe(mockToolRegistry);
|
||||
expect(schedulerConfig.toolRegistry).toBe(mockToolRegistry);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { GeminiClient } from '../core/client.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import type { ToolRegistry } from '../tools/tool-registry.js';
|
||||
|
||||
/**
|
||||
* AgentLoopContext represents the execution-scoped view of the world for a single
|
||||
* agent turn or sub-agent loop.
|
||||
*/
|
||||
export interface AgentLoopContext {
|
||||
/** The unique ID for the current user turn or agent thought loop. */
|
||||
readonly promptId: string;
|
||||
|
||||
/** The registry of tools available to the agent in this context. */
|
||||
readonly toolRegistry: ToolRegistry;
|
||||
|
||||
/** The bus for user confirmations and messages in this context. */
|
||||
readonly messageBus: MessageBus;
|
||||
|
||||
/** The client used to communicate with the LLM in this context. */
|
||||
readonly geminiClient: GeminiClient;
|
||||
}
|
||||
@@ -96,6 +96,7 @@ import type {
|
||||
import { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
|
||||
import { ModelRouterService } from '../routing/modelRouterService.js';
|
||||
import { OutputFormat } from '../output/types.js';
|
||||
//import { type AgentLoopContext } from './agent-loop-context.js';
|
||||
import {
|
||||
ModelConfigService,
|
||||
type ModelConfig,
|
||||
@@ -154,6 +155,7 @@ import { CheckerRunner } from '../safety/checker-runner.js';
|
||||
import { ContextBuilder } from '../safety/context-builder.js';
|
||||
import { CheckerRegistry } from '../safety/registry.js';
|
||||
import { ConsecaSafetyChecker } from '../safety/conseca/conseca.js';
|
||||
import type { AgentLoopContext } from './agent-loop-context.js';
|
||||
|
||||
export interface AccessibilitySettings {
|
||||
/** @deprecated Use ui.loadingPhrases instead. */
|
||||
@@ -598,8 +600,8 @@ export interface ConfigParameters {
|
||||
};
|
||||
}
|
||||
|
||||
export class Config implements McpContext {
|
||||
private toolRegistry!: ToolRegistry;
|
||||
export class Config implements McpContext, AgentLoopContext {
|
||||
private _toolRegistry!: ToolRegistry;
|
||||
private mcpClientManager?: McpClientManager;
|
||||
private allowedMcpServers: string[];
|
||||
private blockedMcpServers: string[];
|
||||
@@ -611,7 +613,7 @@ export class Config implements McpContext {
|
||||
private agentRegistry!: AgentRegistry;
|
||||
private readonly acknowledgedAgentsService: AcknowledgedAgentsService;
|
||||
private skillManager!: SkillManager;
|
||||
private sessionId: string;
|
||||
private _sessionId: string;
|
||||
private clientVersion: string;
|
||||
private fileSystemService: FileSystemService;
|
||||
private trackerService?: TrackerService;
|
||||
@@ -645,7 +647,7 @@ export class Config implements McpContext {
|
||||
private readonly accessibility: AccessibilitySettings;
|
||||
private readonly telemetrySettings: TelemetrySettings;
|
||||
private readonly usageStatisticsEnabled: boolean;
|
||||
private geminiClient!: GeminiClient;
|
||||
private _geminiClient!: GeminiClient;
|
||||
private baseLlmClient!: BaseLlmClient;
|
||||
private localLiteRtLmClient?: LocalLiteRtLmClient;
|
||||
private modelRouterService: ModelRouterService;
|
||||
@@ -740,7 +742,7 @@ export class Config implements McpContext {
|
||||
private readonly fileExclusions: FileExclusions;
|
||||
private readonly eventEmitter?: EventEmitter;
|
||||
private readonly useWriteTodos: boolean;
|
||||
private readonly messageBus: MessageBus;
|
||||
private readonly _messageBus: MessageBus;
|
||||
private readonly policyEngine: PolicyEngine;
|
||||
private policyUpdateConfirmationRequest:
|
||||
| PolicyUpdateConfirmationRequest
|
||||
@@ -806,7 +808,7 @@ export class Config implements McpContext {
|
||||
private approvedPlanPath: string | undefined;
|
||||
|
||||
constructor(params: ConfigParameters) {
|
||||
this.sessionId = params.sessionId;
|
||||
this._sessionId = params.sessionId;
|
||||
this.clientVersion = params.clientVersion ?? 'unknown';
|
||||
this.approvedPlanPath = undefined;
|
||||
this.embeddingModel =
|
||||
@@ -961,7 +963,7 @@ export class Config implements McpContext {
|
||||
(params.shellToolInactivityTimeout ?? 300) * 1000; // 5 minutes
|
||||
this.extensionManagement = params.extensionManagement ?? true;
|
||||
this.enableExtensionReloading = params.enableExtensionReloading ?? false;
|
||||
this.storage = new Storage(this.targetDir, this.sessionId);
|
||||
this.storage = new Storage(this.targetDir, this._sessionId);
|
||||
this.storage.setCustomPlansDir(params.planSettings?.directory);
|
||||
|
||||
this.fakeResponses = params.fakeResponses;
|
||||
@@ -997,7 +999,7 @@ export class Config implements McpContext {
|
||||
ConsecaSafetyChecker.getInstance().setConfig(this);
|
||||
}
|
||||
|
||||
this.messageBus = new MessageBus(this.policyEngine, this.debugMode);
|
||||
this._messageBus = new MessageBus(this.policyEngine, this.debugMode);
|
||||
this.acknowledgedAgentsService = new AcknowledgedAgentsService();
|
||||
this.skillManager = new SkillManager();
|
||||
this.outputSettings = {
|
||||
@@ -1057,7 +1059,7 @@ export class Config implements McpContext {
|
||||
);
|
||||
}
|
||||
}
|
||||
this.geminiClient = new GeminiClient(this);
|
||||
this._geminiClient = new GeminiClient(this);
|
||||
this.modelRouterService = new ModelRouterService(this);
|
||||
|
||||
// HACK: The settings loading logic doesn't currently merge the default
|
||||
@@ -1142,11 +1144,11 @@ export class Config implements McpContext {
|
||||
|
||||
coreEvents.on(CoreEvent.AgentsRefreshed, this.onAgentsRefreshed);
|
||||
|
||||
this.toolRegistry = await this.createToolRegistry();
|
||||
this._toolRegistry = await this.createToolRegistry();
|
||||
discoverToolsHandle?.end();
|
||||
this.mcpClientManager = new McpClientManager(
|
||||
this.clientVersion,
|
||||
this.toolRegistry,
|
||||
this._toolRegistry,
|
||||
this,
|
||||
this.eventEmitter,
|
||||
);
|
||||
@@ -1181,7 +1183,7 @@ export class Config implements McpContext {
|
||||
if (this.getSkillManager().getSkills().length > 0) {
|
||||
this.getToolRegistry().unregisterTool(ActivateSkillTool.Name);
|
||||
this.getToolRegistry().registerTool(
|
||||
new ActivateSkillTool(this, this.messageBus),
|
||||
new ActivateSkillTool(this, this._messageBus),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1198,7 +1200,7 @@ export class Config implements McpContext {
|
||||
await this.contextManager.refresh();
|
||||
}
|
||||
|
||||
await this.geminiClient.initialize();
|
||||
await this._geminiClient.initialize();
|
||||
this.initialized = true;
|
||||
}
|
||||
|
||||
@@ -1222,7 +1224,7 @@ export class Config implements McpContext {
|
||||
authMethod !== AuthType.USE_GEMINI
|
||||
) {
|
||||
// Restore the conversation history to the new client
|
||||
this.geminiClient.stripThoughtsFromHistory();
|
||||
this._geminiClient.stripThoughtsFromHistory();
|
||||
}
|
||||
|
||||
// Reset availability status when switching auth (e.g. from limited key to OAuth)
|
||||
@@ -1343,12 +1345,28 @@ export class Config implements McpContext {
|
||||
return this.localLiteRtLmClient;
|
||||
}
|
||||
|
||||
get promptId(): string {
|
||||
return this._sessionId;
|
||||
}
|
||||
|
||||
get toolRegistry(): ToolRegistry {
|
||||
return this._toolRegistry;
|
||||
}
|
||||
|
||||
get messageBus(): MessageBus {
|
||||
return this._messageBus;
|
||||
}
|
||||
|
||||
get geminiClient(): GeminiClient {
|
||||
return this._geminiClient;
|
||||
}
|
||||
|
||||
getSessionId(): string {
|
||||
return this.sessionId;
|
||||
return this.promptId;
|
||||
}
|
||||
|
||||
setSessionId(sessionId: string): void {
|
||||
this.sessionId = sessionId;
|
||||
this._sessionId = sessionId;
|
||||
}
|
||||
|
||||
setTerminalBackground(terminalBackground: string | undefined): void {
|
||||
@@ -1613,6 +1631,7 @@ export class Config implements McpContext {
|
||||
return this.acknowledgedAgentsService;
|
||||
}
|
||||
|
||||
/** @deprecated Use toolRegistry getter */
|
||||
getToolRegistry(): ToolRegistry {
|
||||
return this.toolRegistry;
|
||||
}
|
||||
@@ -1889,9 +1908,9 @@ export class Config implements McpContext {
|
||||
);
|
||||
await refreshServerHierarchicalMemory(this);
|
||||
}
|
||||
if (this.geminiClient?.isInitialized()) {
|
||||
await this.geminiClient.setTools();
|
||||
this.geminiClient.updateSystemInstruction();
|
||||
if (this._geminiClient?.isInitialized()) {
|
||||
await this._geminiClient.setTools();
|
||||
this._geminiClient.updateSystemInstruction();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2045,8 +2064,8 @@ export class Config implements McpContext {
|
||||
(currentMode === ApprovalMode.YOLO || mode === ApprovalMode.YOLO);
|
||||
|
||||
if (isPlanModeTransition || isYoloModeTransition) {
|
||||
if (this.geminiClient?.isInitialized()) {
|
||||
this.geminiClient.setTools().catch((err) => {
|
||||
if (this._geminiClient?.isInitialized()) {
|
||||
this._geminiClient.setTools().catch((err) => {
|
||||
debugLogger.error('Failed to update tools', err);
|
||||
});
|
||||
}
|
||||
@@ -2142,6 +2161,7 @@ export class Config implements McpContext {
|
||||
return this.telemetrySettings.useCliAuth ?? false;
|
||||
}
|
||||
|
||||
/** @deprecated Use geminiClient getter */
|
||||
getGeminiClient(): GeminiClient {
|
||||
return this.geminiClient;
|
||||
}
|
||||
@@ -2577,7 +2597,7 @@ export class Config implements McpContext {
|
||||
if (this.getSkillManager().getSkills().length > 0) {
|
||||
this.getToolRegistry().unregisterTool(ActivateSkillTool.Name);
|
||||
this.getToolRegistry().registerTool(
|
||||
new ActivateSkillTool(this, this.messageBus),
|
||||
new ActivateSkillTool(this, this._messageBus),
|
||||
);
|
||||
} else {
|
||||
this.getToolRegistry().unregisterTool(ActivateSkillTool.Name);
|
||||
@@ -2703,6 +2723,7 @@ export class Config implements McpContext {
|
||||
return this.fileExclusions;
|
||||
}
|
||||
|
||||
/** @deprecated Use messageBus getter */
|
||||
getMessageBus(): MessageBus {
|
||||
return this.messageBus;
|
||||
}
|
||||
@@ -2760,7 +2781,7 @@ export class Config implements McpContext {
|
||||
}
|
||||
|
||||
async createToolRegistry(): Promise<ToolRegistry> {
|
||||
const registry = new ToolRegistry(this, this.messageBus);
|
||||
const registry = new ToolRegistry(this, this._messageBus);
|
||||
|
||||
// helper to create & register core tools that are enabled
|
||||
const maybeRegister = (
|
||||
@@ -2790,10 +2811,10 @@ export class Config implements McpContext {
|
||||
};
|
||||
|
||||
maybeRegister(LSTool, () =>
|
||||
registry.registerTool(new LSTool(this, this.messageBus)),
|
||||
registry.registerTool(new LSTool(this, this._messageBus)),
|
||||
);
|
||||
maybeRegister(ReadFileTool, () =>
|
||||
registry.registerTool(new ReadFileTool(this, this.messageBus)),
|
||||
registry.registerTool(new ReadFileTool(this, this._messageBus)),
|
||||
);
|
||||
|
||||
if (this.getUseRipgrep()) {
|
||||
@@ -2806,81 +2827,85 @@ export class Config implements McpContext {
|
||||
}
|
||||
if (useRipgrep) {
|
||||
maybeRegister(RipGrepTool, () =>
|
||||
registry.registerTool(new RipGrepTool(this, this.messageBus)),
|
||||
registry.registerTool(new RipGrepTool(this, this._messageBus)),
|
||||
);
|
||||
} else {
|
||||
logRipgrepFallback(this, new RipgrepFallbackEvent(errorString));
|
||||
maybeRegister(GrepTool, () =>
|
||||
registry.registerTool(new GrepTool(this, this.messageBus)),
|
||||
registry.registerTool(new GrepTool(this, this._messageBus)),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
maybeRegister(GrepTool, () =>
|
||||
registry.registerTool(new GrepTool(this, this.messageBus)),
|
||||
registry.registerTool(new GrepTool(this, this._messageBus)),
|
||||
);
|
||||
}
|
||||
|
||||
maybeRegister(GlobTool, () =>
|
||||
registry.registerTool(new GlobTool(this, this.messageBus)),
|
||||
registry.registerTool(new GlobTool(this, this._messageBus)),
|
||||
);
|
||||
maybeRegister(ActivateSkillTool, () =>
|
||||
registry.registerTool(new ActivateSkillTool(this, this.messageBus)),
|
||||
registry.registerTool(new ActivateSkillTool(this, this._messageBus)),
|
||||
);
|
||||
maybeRegister(EditTool, () =>
|
||||
registry.registerTool(new EditTool(this, this.messageBus)),
|
||||
registry.registerTool(new EditTool(this, this._messageBus)),
|
||||
);
|
||||
maybeRegister(WriteFileTool, () =>
|
||||
registry.registerTool(new WriteFileTool(this, this.messageBus)),
|
||||
registry.registerTool(new WriteFileTool(this, this._messageBus)),
|
||||
);
|
||||
maybeRegister(WebFetchTool, () =>
|
||||
registry.registerTool(new WebFetchTool(this, this.messageBus)),
|
||||
registry.registerTool(new WebFetchTool(this, this._messageBus)),
|
||||
);
|
||||
maybeRegister(ShellTool, () =>
|
||||
registry.registerTool(new ShellTool(this, this.messageBus)),
|
||||
registry.registerTool(new ShellTool(this, this._messageBus)),
|
||||
);
|
||||
maybeRegister(MemoryTool, () =>
|
||||
registry.registerTool(new MemoryTool(this.messageBus)),
|
||||
registry.registerTool(new MemoryTool(this._messageBus)),
|
||||
);
|
||||
maybeRegister(WebSearchTool, () =>
|
||||
registry.registerTool(new WebSearchTool(this, this.messageBus)),
|
||||
registry.registerTool(new WebSearchTool(this, this._messageBus)),
|
||||
);
|
||||
maybeRegister(AskUserTool, () =>
|
||||
registry.registerTool(new AskUserTool(this.messageBus)),
|
||||
registry.registerTool(new AskUserTool(this._messageBus)),
|
||||
);
|
||||
if (this.getUseWriteTodos()) {
|
||||
maybeRegister(WriteTodosTool, () =>
|
||||
registry.registerTool(new WriteTodosTool(this.messageBus)),
|
||||
registry.registerTool(new WriteTodosTool(this._messageBus)),
|
||||
);
|
||||
}
|
||||
if (this.isPlanEnabled()) {
|
||||
maybeRegister(ExitPlanModeTool, () =>
|
||||
registry.registerTool(new ExitPlanModeTool(this, this.messageBus)),
|
||||
registry.registerTool(new ExitPlanModeTool(this, this._messageBus)),
|
||||
);
|
||||
maybeRegister(EnterPlanModeTool, () =>
|
||||
registry.registerTool(new EnterPlanModeTool(this, this.messageBus)),
|
||||
registry.registerTool(new EnterPlanModeTool(this, this._messageBus)),
|
||||
);
|
||||
}
|
||||
|
||||
if (this.isTrackerEnabled()) {
|
||||
maybeRegister(TrackerCreateTaskTool, () =>
|
||||
registry.registerTool(new TrackerCreateTaskTool(this, this.messageBus)),
|
||||
registry.registerTool(
|
||||
new TrackerCreateTaskTool(this, this._messageBus),
|
||||
),
|
||||
);
|
||||
maybeRegister(TrackerUpdateTaskTool, () =>
|
||||
registry.registerTool(new TrackerUpdateTaskTool(this, this.messageBus)),
|
||||
registry.registerTool(
|
||||
new TrackerUpdateTaskTool(this, this._messageBus),
|
||||
),
|
||||
);
|
||||
maybeRegister(TrackerGetTaskTool, () =>
|
||||
registry.registerTool(new TrackerGetTaskTool(this, this.messageBus)),
|
||||
registry.registerTool(new TrackerGetTaskTool(this, this._messageBus)),
|
||||
);
|
||||
maybeRegister(TrackerListTasksTool, () =>
|
||||
registry.registerTool(new TrackerListTasksTool(this, this.messageBus)),
|
||||
registry.registerTool(new TrackerListTasksTool(this, this._messageBus)),
|
||||
);
|
||||
maybeRegister(TrackerAddDependencyTool, () =>
|
||||
registry.registerTool(
|
||||
new TrackerAddDependencyTool(this, this.messageBus),
|
||||
new TrackerAddDependencyTool(this, this._messageBus),
|
||||
),
|
||||
);
|
||||
maybeRegister(TrackerVisualizeTool, () =>
|
||||
registry.registerTool(new TrackerVisualizeTool(this, this.messageBus)),
|
||||
registry.registerTool(new TrackerVisualizeTool(this, this._messageBus)),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -3007,8 +3032,8 @@ export class Config implements McpContext {
|
||||
}
|
||||
|
||||
private onAgentsRefreshed = async () => {
|
||||
if (this.toolRegistry) {
|
||||
this.registerSubAgentTools(this.toolRegistry);
|
||||
if (this._toolRegistry) {
|
||||
this.registerSubAgentTools(this._toolRegistry);
|
||||
}
|
||||
// Propagate updates to the active chat session
|
||||
const client = this.getGeminiClient();
|
||||
@@ -3029,7 +3054,7 @@ export class Config implements McpContext {
|
||||
this.logCurrentModeDuration(this.getApprovalMode());
|
||||
coreEvents.off(CoreEvent.AgentsRefreshed, this.onAgentsRefreshed);
|
||||
this.agentRegistry?.dispose();
|
||||
this.geminiClient?.dispose();
|
||||
this._geminiClient?.dispose();
|
||||
if (this.mcpClientManager) {
|
||||
await this.mcpClientManager.stop();
|
||||
}
|
||||
|
||||
@@ -290,6 +290,8 @@ function createMockConfig(overrides: Partial<Config> = {}): Config {
|
||||
|
||||
const finalConfig = { ...baseConfig, ...overrides } as Config;
|
||||
|
||||
(finalConfig as unknown as { config: Config }).config = finalConfig;
|
||||
|
||||
// Patch the policy engine to use the final config if not overridden
|
||||
if (!overrides.getPolicyEngine) {
|
||||
finalConfig.getPolicyEngine = () =>
|
||||
|
||||
@@ -133,7 +133,7 @@ export class CoreToolScheduler {
|
||||
this.onAllToolCallsComplete = options.onAllToolCallsComplete;
|
||||
this.onToolCallsUpdate = options.onToolCallsUpdate;
|
||||
this.getPreferredEditor = options.getPreferredEditor;
|
||||
this.toolExecutor = new ToolExecutor(this.config);
|
||||
this.toolExecutor = new ToolExecutor(this.config, this.config);
|
||||
this.toolModifier = new ToolModificationHandler();
|
||||
|
||||
// Subscribe to message bus for ASK_USER policy decisions
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
// Export config
|
||||
export * from './config/config.js';
|
||||
export * from './config/agent-loop-context.js';
|
||||
export * from './config/memory.js';
|
||||
export * from './config/defaultModelConfigs.js';
|
||||
export * from './config/models.js';
|
||||
|
||||
@@ -49,6 +49,9 @@ describe('policy.ts', () => {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
|
||||
const toolCall = {
|
||||
request: { name: 'test-tool', args: {} },
|
||||
tool: { name: 'test-tool' },
|
||||
@@ -72,6 +75,9 @@ describe('policy.ts', () => {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
|
||||
const mcpTool = Object.create(DiscoveredMCPTool.prototype);
|
||||
mcpTool.serverName = 'my-server';
|
||||
mcpTool._toolAnnotations = { readOnlyHint: true };
|
||||
@@ -99,6 +105,9 @@ describe('policy.ts', () => {
|
||||
isInteractive: vi.fn().mockReturnValue(false),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
|
||||
const toolCall = {
|
||||
request: { name: 'test-tool', args: {} },
|
||||
tool: { name: 'test-tool' },
|
||||
@@ -118,6 +127,9 @@ describe('policy.ts', () => {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
|
||||
const toolCall = {
|
||||
request: { name: 'test-tool', args: {} },
|
||||
tool: { name: 'test-tool' },
|
||||
@@ -137,6 +149,9 @@ describe('policy.ts', () => {
|
||||
isInteractive: vi.fn().mockReturnValue(true),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
|
||||
const toolCall = {
|
||||
request: { name: 'test-tool', args: {} },
|
||||
tool: { name: 'test-tool' },
|
||||
@@ -152,6 +167,9 @@ describe('policy.ts', () => {
|
||||
const mockConfig = {
|
||||
setApprovalMode: vi.fn(),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
const mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
} as unknown as Mocked<MessageBus>;
|
||||
@@ -175,6 +193,9 @@ describe('policy.ts', () => {
|
||||
const mockConfig = {
|
||||
setApprovalMode: vi.fn(),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
const mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
} as unknown as Mocked<MessageBus>;
|
||||
@@ -200,6 +221,9 @@ describe('policy.ts', () => {
|
||||
const mockConfig = {
|
||||
setApprovalMode: vi.fn(),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
const mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
} as unknown as Mocked<MessageBus>;
|
||||
@@ -225,6 +249,9 @@ describe('policy.ts', () => {
|
||||
const mockConfig = {
|
||||
setApprovalMode: vi.fn(),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
const mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
} as unknown as Mocked<MessageBus>;
|
||||
@@ -256,6 +283,9 @@ describe('policy.ts', () => {
|
||||
const mockConfig = {
|
||||
setApprovalMode: vi.fn(),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
const mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
} as unknown as Mocked<MessageBus>;
|
||||
@@ -290,6 +320,9 @@ describe('policy.ts', () => {
|
||||
const mockConfig = {
|
||||
setApprovalMode: vi.fn(),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
const mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
} as unknown as Mocked<MessageBus>;
|
||||
@@ -308,6 +341,9 @@ describe('policy.ts', () => {
|
||||
const mockConfig = {
|
||||
setApprovalMode: vi.fn(),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
const mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
} as unknown as Mocked<MessageBus>;
|
||||
@@ -325,6 +361,9 @@ describe('policy.ts', () => {
|
||||
const mockConfig = {
|
||||
setApprovalMode: vi.fn(),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
const mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
} as unknown as Mocked<MessageBus>;
|
||||
@@ -344,6 +383,9 @@ describe('policy.ts', () => {
|
||||
const mockConfig = {
|
||||
setApprovalMode: vi.fn(),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
const mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
} as unknown as Mocked<MessageBus>;
|
||||
@@ -378,6 +420,9 @@ describe('policy.ts', () => {
|
||||
const mockConfig = {
|
||||
setApprovalMode: vi.fn(),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
const mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
} as unknown as Mocked<MessageBus>;
|
||||
@@ -410,6 +455,9 @@ describe('policy.ts', () => {
|
||||
const mockConfig = {
|
||||
setApprovalMode: vi.fn(),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config =
|
||||
mockConfig as Config;
|
||||
const mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
} as unknown as Mocked<MessageBus>;
|
||||
@@ -447,6 +495,8 @@ describe('policy.ts', () => {
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
|
||||
} as unknown as Config;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config = mockConfig;
|
||||
|
||||
const { errorMessage, errorType } = getPolicyDenialError(mockConfig);
|
||||
|
||||
expect(errorMessage).toBe('Tool execution denied by policy.');
|
||||
@@ -457,6 +507,8 @@ describe('policy.ts', () => {
|
||||
const mockConfig = {
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
|
||||
} as unknown as Config;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config = mockConfig;
|
||||
const rule = {
|
||||
decision: PolicyDecision.DENY,
|
||||
denyMessage: 'Custom Deny',
|
||||
@@ -517,7 +569,8 @@ describe('Plan Mode Denial Consistency', () => {
|
||||
|
||||
mockConfig = {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
toolRegistry: mockToolRegistry,
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
|
||||
isInteractive: vi.fn().mockReturnValue(true),
|
||||
getEnableHooks: vi.fn().mockReturnValue(false),
|
||||
@@ -525,6 +578,7 @@ describe('Plan Mode Denial Consistency', () => {
|
||||
setApprovalMode: vi.fn(),
|
||||
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
|
||||
} as unknown as Mocked<Config>;
|
||||
(mockConfig as unknown as { config: Config }).config = mockConfig as Config;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
|
||||
@@ -169,13 +169,15 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
|
||||
mockConfig = {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
toolRegistry: mockToolRegistry,
|
||||
isInteractive: vi.fn().mockReturnValue(true),
|
||||
getEnableHooks: vi.fn().mockReturnValue(true),
|
||||
setApprovalMode: vi.fn(),
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config = mockConfig as Config;
|
||||
|
||||
mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
subscribe: vi.fn(),
|
||||
@@ -1320,6 +1322,8 @@ describe('Scheduler MCP Progress', () => {
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config = mockConfig as Config;
|
||||
|
||||
mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
subscribe: vi.fn(),
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { AgentLoopContext } from '../config/agent-loop-context.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { SchedulerStateManager } from './state-manager.js';
|
||||
import { resolveConfirmation } from './confirmation.js';
|
||||
@@ -57,7 +58,7 @@ interface SchedulerQueueItem {
|
||||
|
||||
export interface SchedulerOptions {
|
||||
config: Config;
|
||||
messageBus: MessageBus;
|
||||
messageBus?: MessageBus;
|
||||
getPreferredEditor: () => EditorType | undefined;
|
||||
schedulerId: string;
|
||||
parentCallId?: string;
|
||||
@@ -97,6 +98,7 @@ export class Scheduler {
|
||||
private readonly executor: ToolExecutor;
|
||||
private readonly modifier: ToolModificationHandler;
|
||||
private readonly config: Config;
|
||||
private readonly context: AgentLoopContext;
|
||||
private readonly messageBus: MessageBus;
|
||||
private readonly getPreferredEditor: () => EditorType | undefined;
|
||||
private readonly schedulerId: string;
|
||||
@@ -109,7 +111,8 @@ export class Scheduler {
|
||||
|
||||
constructor(options: SchedulerOptions) {
|
||||
this.config = options.config;
|
||||
this.messageBus = options.messageBus;
|
||||
this.context = options.config;
|
||||
this.messageBus = options.messageBus ?? this.context.messageBus;
|
||||
this.getPreferredEditor = options.getPreferredEditor;
|
||||
this.schedulerId = options.schedulerId;
|
||||
this.parentCallId = options.parentCallId;
|
||||
@@ -119,7 +122,7 @@ export class Scheduler {
|
||||
this.schedulerId,
|
||||
(call) => logToolCall(this.config, new ToolCallEvent(call)),
|
||||
);
|
||||
this.executor = new ToolExecutor(this.config);
|
||||
this.executor = new ToolExecutor(this.config, this.context);
|
||||
this.modifier = new ToolModificationHandler();
|
||||
|
||||
this.setupMessageBusListener(this.messageBus);
|
||||
@@ -294,7 +297,7 @@ export class Scheduler {
|
||||
const currentApprovalMode = this.config.getApprovalMode();
|
||||
|
||||
try {
|
||||
const toolRegistry = this.config.getToolRegistry();
|
||||
const toolRegistry = this.context.toolRegistry;
|
||||
const newCalls: ToolCall[] = requests.map((request) => {
|
||||
const enrichedRequest: ToolCallRequestInfo = {
|
||||
...request,
|
||||
@@ -697,7 +700,7 @@ export class Scheduler {
|
||||
const originalRequestName =
|
||||
result.request.originalRequestName || result.request.name;
|
||||
|
||||
const newTool = this.config.getToolRegistry().getTool(tailRequest.name);
|
||||
const newTool = this.context.toolRegistry.getTool(tailRequest.name);
|
||||
|
||||
const newRequest: ToolCallRequestInfo = {
|
||||
callId: originalCallId,
|
||||
@@ -713,7 +716,7 @@ export class Scheduler {
|
||||
// Enqueue an errored tool call
|
||||
const errorCall = this._createToolNotFoundErroredToolCall(
|
||||
newRequest,
|
||||
this.config.getToolRegistry().getAllToolNames(),
|
||||
this.context.toolRegistry.getAllToolNames(),
|
||||
);
|
||||
this.state.replaceActiveCallWithTailCall(callId, errorCall);
|
||||
} else {
|
||||
|
||||
@@ -211,13 +211,15 @@ describe('Scheduler Parallel Execution', () => {
|
||||
|
||||
mockConfig = {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
toolRegistry: mockToolRegistry,
|
||||
isInteractive: vi.fn().mockReturnValue(true),
|
||||
getEnableHooks: vi.fn().mockReturnValue(true),
|
||||
setApprovalMode: vi.fn(),
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
(mockConfig as unknown as { config: Config }).config = mockConfig as Config;
|
||||
|
||||
mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
subscribe: vi.fn(),
|
||||
|
||||
@@ -36,7 +36,7 @@ describe('Scheduler waiting callback', () => {
|
||||
|
||||
mockTool = new MockTool({ name: 'test_tool' });
|
||||
toolRegistry = new ToolRegistry(mockConfig, messageBus);
|
||||
vi.spyOn(mockConfig, 'getToolRegistry').mockReturnValue(toolRegistry);
|
||||
vi.spyOn(mockConfig, 'toolRegistry', 'get').mockReturnValue(toolRegistry);
|
||||
toolRegistry.registerTool(mockTool);
|
||||
|
||||
vi.mocked(checkPolicy).mockResolvedValue({
|
||||
|
||||
@@ -64,7 +64,7 @@ describe('ToolExecutor', () => {
|
||||
beforeEach(() => {
|
||||
// Use the standard fake config factory
|
||||
config = makeFakeConfig();
|
||||
executor = new ToolExecutor(config);
|
||||
executor = new ToolExecutor(config, config);
|
||||
|
||||
// Reset mocks
|
||||
vi.resetAllMocks();
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
type ToolCallResponseInfo,
|
||||
type ToolResult,
|
||||
type Config,
|
||||
type AgentLoopContext,
|
||||
type ToolLiveOutput,
|
||||
} from '../index.js';
|
||||
import { SHELL_TOOL_NAME } from '../tools/tool-names.js';
|
||||
@@ -49,7 +50,10 @@ export interface ToolExecutionContext {
|
||||
}
|
||||
|
||||
export class ToolExecutor {
|
||||
constructor(private readonly config: Config) {}
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly context: AgentLoopContext,
|
||||
) {}
|
||||
|
||||
async execute(context: ToolExecutionContext): Promise<CompletedToolCall> {
|
||||
const { call, signal, outputUpdateHandler, onUpdateToolCall } = context;
|
||||
@@ -202,7 +206,7 @@ export class ToolExecutor {
|
||||
toolName,
|
||||
callId,
|
||||
this.config.storage.getProjectTempDir(),
|
||||
this.config.getSessionId(),
|
||||
this.context.promptId,
|
||||
);
|
||||
outputFile = savedPath;
|
||||
const truncatedContent = formatTruncatedToolOutput(
|
||||
@@ -241,7 +245,7 @@ export class ToolExecutor {
|
||||
toolName,
|
||||
callId,
|
||||
this.config.storage.getProjectTempDir(),
|
||||
this.config.getSessionId(),
|
||||
this.context.promptId,
|
||||
);
|
||||
outputFile = savedPath;
|
||||
const truncatedText = formatTruncatedToolOutput(
|
||||
|
||||
Reference in New Issue
Block a user