diff --git a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx index be85d8ba73..45584a9d46 100644 --- a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx @@ -88,14 +88,13 @@ export const ToolConfirmationMessage: React.FC< const settings = useSettings(); const allowPermanentApproval = settings.merged.security.enablePermanentToolApproval && - (config?.getDisableAlwaysAllow ? !config.getDisableAlwaysAllow() : true); + !config.getDisableAlwaysAllow(); const handlesOwnUI = confirmationDetails.type === 'ask_user' || confirmationDetails.type === 'exit_plan_mode'; const isTrustedFolder = - config?.isTrustedFolder?.() && - (config?.getDisableAlwaysAllow ? !config.getDisableAlwaysAllow() : true); + config.isTrustedFolder() && !config.getDisableAlwaysAllow(); const handleConfirm = useCallback( (outcome: ToolConfirmationOutcome, payload?: ToolConfirmationPayload) => { diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index e1db5c6e8e..c44004e5b1 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -3446,4 +3446,38 @@ describe('ConfigSchema validation', () => { expect(result.data.sandbox?.networkAccess).toBe(false); } }); + + describe('AgentLoopContext Spread Safety', () => { + it('should preserve AgentLoopContext properties when Config is spread', async () => { + const config = new Config({ + targetDir: '/tmp/test', + sessionId: 'test-session', + debugMode: false, + cwd: '/tmp/test', + model: 'auto', + }); + await config.initialize(); + + // Spread the config instance into a new object + const context: AgentLoopContext = { ...config }; + + // Verify all AgentLoopContext properties are present + expect(context.promptId).toBe('test-session'); + expect(context.config).toBe(config); + expect(context.toolRegistry).toBeDefined(); + expect(context.promptRegistry).toBeDefined(); + expect(context.resourceRegistry).toBeDefined(); + expect(context.messageBus).toBeDefined(); + expect(context.geminiClient).toBeDefined(); + expect(context.sandboxManager).toBeDefined(); + + // Verify they are the same instances + expect(context.toolRegistry).toBe(config.toolRegistry); + expect(context.promptRegistry).toBe(config.promptRegistry); + expect(context.resourceRegistry).toBe(config.resourceRegistry); + expect(context.messageBus).toBe(config.messageBus); + expect(context.geminiClient).toBe(config.geminiClient); + expect(context.sandboxManager).toBe(config.sandboxManager); + }); + }); }); diff --git a/packages/core/src/scheduler/policy.ts b/packages/core/src/scheduler/policy.ts index a59aa16510..82a613d4cd 100644 --- a/packages/core/src/scheduler/policy.ts +++ b/packages/core/src/scheduler/policy.ts @@ -119,7 +119,7 @@ export async function updatePolicy( ): Promise { // Mode Transitions (AUTO_EDIT) if (isAutoEditTransition(tool, outcome)) { - context.config?.setApprovalMode?.(ApprovalMode.AUTO_EDIT); + context.config.setApprovalMode(ApprovalMode.AUTO_EDIT); return; } @@ -128,8 +128,8 @@ export async function updatePolicy( if (outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave) { // If folder is trusted and workspace policies are enabled, we prefer workspace scope. if ( - context.config?.isTrustedFolder?.() && - context.config?.getWorkspacePoliciesDir?.() !== undefined + context.config.isTrustedFolder() && + context.config.getWorkspacePoliciesDir() !== undefined ) { persistScope = 'workspace'; } else { diff --git a/packages/core/src/tools/shell.ts b/packages/core/src/tools/shell.ts index 6301805814..821e30c539 100644 --- a/packages/core/src/tools/shell.ts +++ b/packages/core/src/tools/shell.ts @@ -45,8 +45,6 @@ import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { getShellDefinition } from './definitions/coreTools.js'; import { resolveToolDeclaration } from './definitions/resolver.js'; import type { AgentLoopContext } from '../config/agent-loop-context.js'; -import type { Config } from '../config/config.js'; -import type { GeminiClient } from '../core/client.js'; export const OUTPUT_UPDATE_INTERVAL_MS = 1000; @@ -65,12 +63,11 @@ export class ShellToolInvocation extends BaseToolInvocation< ToolResult > { constructor( - private readonly config: Config, + private readonly context: AgentLoopContext, params: ShellToolParams, messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, - private readonly geminiClient?: GeminiClient, ) { super(params, messageBus, _toolName, _toolDisplayName); } @@ -171,7 +168,7 @@ export class ShellToolInvocation extends BaseToolInvocation< .toString('hex')}.tmp`; const tempFilePath = path.join(os.tmpdir(), tempFileName); - const timeoutMs = this.config.getShellToolInactivityTimeout(); + const timeoutMs = this.context.config.getShellToolInactivityTimeout(); const timeoutController = new AbortController(); let timeoutTimer: NodeJS.Timeout | undefined; @@ -192,10 +189,10 @@ export class ShellToolInvocation extends BaseToolInvocation< })(); const cwd = this.params.dir_path - ? path.resolve(this.config.getTargetDir(), this.params.dir_path) - : this.config.getTargetDir(); + ? path.resolve(this.context.config.getTargetDir(), this.params.dir_path) + : this.context.config.getTargetDir(); - const validationError = this.config.validatePathAccess(cwd); + const validationError = this.context.config.validatePathAccess(cwd); if (validationError) { return { llmContent: validationError, @@ -274,14 +271,14 @@ export class ShellToolInvocation extends BaseToolInvocation< } }, combinedController.signal, - this.config.getEnableInteractiveShell(), + this.context.config.getEnableInteractiveShell(), { ...shellExecutionConfig, pager: 'cat', sanitizationConfig: shellExecutionConfig?.sanitizationConfig ?? - this.config.sanitizationConfig, - sandboxManager: this.config.sandboxManager, + this.context.config.sanitizationConfig, + sandboxManager: this.context.config.sandboxManager, }, ); @@ -386,7 +383,7 @@ export class ShellToolInvocation extends BaseToolInvocation< } let returnDisplayMessage = ''; - if (this.config.getDebugMode()) { + if (this.context.config.getDebugMode()) { returnDisplayMessage = llmContent; } else { if (this.params.is_background || result.backgrounded) { @@ -415,7 +412,8 @@ export class ShellToolInvocation extends BaseToolInvocation< } } - const summarizeConfig = this.config.getSummarizeToolOutputConfig(); + const summarizeConfig = + this.context.config.getSummarizeToolOutputConfig(); const executionError = result.error ? { error: { @@ -426,10 +424,10 @@ export class ShellToolInvocation extends BaseToolInvocation< : {}; if (summarizeConfig && summarizeConfig[SHELL_TOOL_NAME]) { const summary = await summarizeToolOutput( - this.config, + this.context.config, { model: 'summarizer-shell' }, llmContent, - this.geminiClient ?? this.config.getGeminiClient(), + this.context.geminiClient, signal, ); return { @@ -463,17 +461,17 @@ export class ShellTool extends BaseDeclarativeTool< ToolResult > { static readonly Name = SHELL_TOOL_NAME; - private readonly config: Config; - private readonly geminiClient?: GeminiClient; - constructor(context: Config | AgentLoopContext, messageBus: MessageBus) { + constructor( + private readonly context: AgentLoopContext, + messageBus: MessageBus, + ) { void initializeShellParsers().catch(() => { // Errors are surfaced when parsing commands. }); - const config = 'config' in context ? context.config : context; const definition = getShellDefinition( - config.getEnableInteractiveShell(), - config.getEnableShellOutputEfficiency(), + context.config.getEnableInteractiveShell(), + context.config.getEnableShellOutputEfficiency(), ); super( ShellTool.Name, @@ -485,10 +483,6 @@ export class ShellTool extends BaseDeclarativeTool< false, // isOutputMarkdown true, // canUpdateOutput ); - this.config = config; - if ('config' in context) { - this.geminiClient = context.geminiClient; - } } protected override validateToolParamValues( @@ -500,10 +494,10 @@ export class ShellTool extends BaseDeclarativeTool< if (params.dir_path) { const resolvedPath = path.resolve( - this.config.getTargetDir(), + this.context.config.getTargetDir(), params.dir_path, ); - return this.config.validatePathAccess(resolvedPath); + return this.context.config.validatePathAccess(resolvedPath); } return null; } @@ -515,19 +509,18 @@ export class ShellTool extends BaseDeclarativeTool< _toolDisplayName?: string, ): ToolInvocation { return new ShellToolInvocation( - this.config, + this.context, params, messageBus, _toolName, _toolDisplayName, - this.geminiClient, ); } override getSchema(modelId?: string) { const definition = getShellDefinition( - this.config.getEnableInteractiveShell(), - this.config.getEnableShellOutputEfficiency(), + this.context.config.getEnableInteractiveShell(), + this.context.config.getEnableShellOutputEfficiency(), ); return resolveToolDeclaration(definition, modelId); } diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts index 0fe53bdf7a..ecd0ecddea 100644 --- a/packages/core/src/tools/web-fetch.ts +++ b/packages/core/src/tools/web-fetch.ts @@ -38,8 +38,6 @@ import { WEB_FETCH_DEFINITION } from './definitions/coreTools.js'; import { resolveToolDeclaration } from './definitions/resolver.js'; import { LRUCache } from 'mnemonist'; import type { AgentLoopContext } from '../config/agent-loop-context.js'; -import type { Config } from '../config/config.js'; -import type { GeminiClient } from '../core/client.js'; const URL_FETCH_TIMEOUT_MS = 10000; const MAX_CONTENT_LENGTH = 250000; @@ -227,18 +225,17 @@ class WebFetchToolInvocation extends BaseToolInvocation< ToolResult > { constructor( - private readonly config: Config, + private readonly context: AgentLoopContext, params: WebFetchToolParams, messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, - private readonly geminiClient?: GeminiClient, ) { super(params, messageBus, _toolName, _toolDisplayName); } private handleRetry(attempt: number, error: unknown, delayMs: number): void { - const maxAttempts = this.config.getMaxAttempts(); + const maxAttempts = this.context.config.getMaxAttempts(); const modelName = 'Web Fetch'; const errorType = getRetryErrorType(error); @@ -251,7 +248,7 @@ class WebFetchToolInvocation extends BaseToolInvocation< }); logNetworkRetryAttempt( - this.config, + this.context.config, new NetworkRetryAttemptEvent( attempt, maxAttempts, @@ -305,7 +302,7 @@ class WebFetchToolInvocation extends BaseToolInvocation< return res; }, { - retryFetchErrors: this.config.getRetryFetchErrors(), + retryFetchErrors: this.context.config.getRetryFetchErrors(), onRetry: (attempt, error, delayMs) => this.handleRetry(attempt, error, delayMs), signal, @@ -352,7 +349,7 @@ class WebFetchToolInvocation extends BaseToolInvocation< `[WebFetchTool] Skipped private or local host: ${url}`, ); logWebFetchFallbackAttempt( - this.config, + this.context.config, new WebFetchFallbackAttemptEvent('private_ip_skipped'), ); skipped.push(`[Blocked Host] ${url}`); @@ -437,7 +434,7 @@ class WebFetchToolInvocation extends BaseToolInvocation< .join('\n'); try { - const geminiClient = this.geminiClient ?? this.config.getGeminiClient(); + const geminiClient = this.context.geminiClient; const fallbackPrompt = `Follow the user's instructions below using the provided webpage content. @@ -520,7 +517,7 @@ ${aggregatedContent} ): Promise { // Check for AUTO_EDIT approval mode. This tool has a specific behavior // where ProceedAlways switches the entire session to AUTO_EDIT. - if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) { + if (this.context.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) { return false; } @@ -643,7 +640,7 @@ ${aggregatedContent} return res; }, { - retryFetchErrors: this.config.getRetryFetchErrors(), + retryFetchErrors: this.context.config.getRetryFetchErrors(), onRetry: (attempt, error, delayMs) => this.handleRetry(attempt, error, delayMs), signal, @@ -754,7 +751,7 @@ Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response trun } async execute(signal: AbortSignal): Promise { - if (this.config.getDirectWebFetch()) { + if (this.context.config.getDirectWebFetch()) { return this.executeExperimental(signal); } const userPrompt = this.params.prompt!; @@ -777,7 +774,7 @@ Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response trun } try { - const geminiClient = this.geminiClient ?? this.config.getGeminiClient(); + const geminiClient = this.context.geminiClient; const sanitizedPrompt = `Follow the user's instructions to process the authorized URLs. @@ -869,7 +866,7 @@ ${toFetch.join('\n')} `[WebFetchTool] Primary fetch failed, falling back: ${getErrorMessage(error)}`, ); logWebFetchFallbackAttempt( - this.config, + this.context.config, new WebFetchFallbackAttemptEvent('primary_failed'), ); // Simple All-or-Nothing Fallback @@ -886,11 +883,11 @@ export class WebFetchTool extends BaseDeclarativeTool< ToolResult > { static readonly Name = WEB_FETCH_TOOL_NAME; - private readonly config: Config; - private readonly geminiClient?: GeminiClient; - constructor(context: Config | AgentLoopContext, messageBus: MessageBus) { - const config = 'config' in context ? context.config : context; + constructor( + private readonly context: AgentLoopContext, + messageBus: MessageBus, + ) { super( WebFetchTool.Name, 'WebFetch', @@ -901,16 +898,12 @@ export class WebFetchTool extends BaseDeclarativeTool< true, // isOutputMarkdown false, // canUpdateOutput ); - this.config = config; - if ('config' in context) { - this.geminiClient = context.geminiClient; - } } protected override validateToolParamValues( params: WebFetchToolParams, ): string | null { - if (this.config.getDirectWebFetch()) { + if (this.context.config.getDirectWebFetch()) { if (!params.url) { return "The 'url' parameter is required."; } @@ -946,18 +939,17 @@ export class WebFetchTool extends BaseDeclarativeTool< _toolDisplayName?: string, ): ToolInvocation { return new WebFetchToolInvocation( - this.config, + this.context, params, messageBus, _toolName, _toolDisplayName, - this.geminiClient, ); } override getSchema(modelId?: string) { const schema = resolveToolDeclaration(WEB_FETCH_DEFINITION, modelId); - if (this.config.getDirectWebFetch()) { + if (this.context.config.getDirectWebFetch()) { return { ...schema, description: diff --git a/packages/core/src/tools/web-search.ts b/packages/core/src/tools/web-search.ts index 49b8a4f69d..1e7174ca71 100644 --- a/packages/core/src/tools/web-search.ts +++ b/packages/core/src/tools/web-search.ts @@ -23,8 +23,6 @@ import { WEB_SEARCH_DEFINITION } from './definitions/coreTools.js'; import { resolveToolDeclaration } from './definitions/resolver.js'; import { LlmRole } from '../telemetry/llmRole.js'; import type { AgentLoopContext } from '../config/agent-loop-context.js'; -import type { Config } from '../config/config.js'; -import type { GeminiClient } from '../core/client.js'; interface GroundingChunkWeb { uri?: string; @@ -73,12 +71,11 @@ class WebSearchToolInvocation extends BaseToolInvocation< WebSearchToolResult > { constructor( - private readonly config: Config, + private readonly context: AgentLoopContext, params: WebSearchToolParams, messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, - private readonly geminiClient?: GeminiClient, ) { super(params, messageBus, _toolName, _toolDisplayName); } @@ -88,7 +85,7 @@ class WebSearchToolInvocation extends BaseToolInvocation< } async execute(signal: AbortSignal): Promise { - const geminiClient = this.geminiClient ?? this.config.getGeminiClient(); + const geminiClient = this.context.geminiClient; try { const response = await geminiClient.generateContent( @@ -208,11 +205,11 @@ export class WebSearchTool extends BaseDeclarativeTool< WebSearchToolResult > { static readonly Name = WEB_SEARCH_TOOL_NAME; - private readonly config: Config; - private readonly geminiClient?: GeminiClient; - constructor(context: Config | AgentLoopContext, messageBus: MessageBus) { - const config = 'config' in context ? context.config : context; + constructor( + private readonly context: AgentLoopContext, + messageBus: MessageBus, + ) { super( WebSearchTool.Name, 'GoogleSearch', @@ -223,10 +220,6 @@ export class WebSearchTool extends BaseDeclarativeTool< true, // isOutputMarkdown false, // canUpdateOutput ); - this.config = config; - if ('config' in context) { - this.geminiClient = context.geminiClient; - } } /** @@ -250,12 +243,11 @@ export class WebSearchTool extends BaseDeclarativeTool< _toolDisplayName?: string, ): ToolInvocation { return new WebSearchToolInvocation( - this.config, + this.context, params, messageBus ?? this.messageBus, _toolName, _toolDisplayName, - this.geminiClient, ); }