mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-20 10:10:56 -07:00
feat(core): Fully migrate packages/core to AgentLoopContext. (#22115)
This commit is contained in:
@@ -47,6 +47,9 @@ describe('Tool Confirmation Policy Updates', () => {
|
||||
} as unknown as MessageBus;
|
||||
|
||||
mockConfig = {
|
||||
get config() {
|
||||
return this;
|
||||
},
|
||||
getTargetDir: () => rootDir,
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
|
||||
setApprovalMode: vi.fn(),
|
||||
|
||||
@@ -302,7 +302,7 @@ export class McpClient implements McpProgressReporter {
|
||||
this.serverConfig,
|
||||
this.client!,
|
||||
cliConfig,
|
||||
this.toolRegistry.getMessageBus(),
|
||||
this.toolRegistry.messageBus,
|
||||
{
|
||||
...(options ?? {
|
||||
timeout: this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
@@ -1167,7 +1167,7 @@ export async function connectAndDiscover(
|
||||
mcpServerConfig,
|
||||
mcpClient,
|
||||
cliConfig,
|
||||
toolRegistry.getMessageBus(),
|
||||
toolRegistry.messageBus,
|
||||
{ timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC },
|
||||
);
|
||||
|
||||
|
||||
@@ -94,6 +94,13 @@ describe('ShellTool', () => {
|
||||
fs.mkdirSync(path.join(tempRootDir, 'subdir'));
|
||||
|
||||
mockConfig = {
|
||||
get config() {
|
||||
return this;
|
||||
},
|
||||
geminiClient: {
|
||||
stripThoughtsFromHistory: vi.fn(),
|
||||
},
|
||||
|
||||
getAllowedTools: vi.fn().mockReturnValue([]),
|
||||
getApprovalMode: vi.fn().mockReturnValue('strict'),
|
||||
getCoreTools: vi.fn().mockReturnValue([]),
|
||||
@@ -441,7 +448,7 @@ describe('ShellTool', () => {
|
||||
mockConfig,
|
||||
{ model: 'summarizer-shell' },
|
||||
expect.any(String),
|
||||
mockConfig.getGeminiClient(),
|
||||
mockConfig.geminiClient,
|
||||
mockAbortSignal,
|
||||
);
|
||||
expect(result.llmContent).toBe('summarized output');
|
||||
|
||||
@@ -8,7 +8,6 @@ import fsPromises from 'node:fs/promises';
|
||||
import path from 'node:path';
|
||||
import os from 'node:os';
|
||||
import crypto from 'node:crypto';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { debugLogger } from '../index.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import {
|
||||
@@ -45,6 +44,7 @@ import { SHELL_TOOL_NAME } from './tool-names.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { getShellDefinition } from './definitions/coreTools.js';
|
||||
import { resolveToolDeclaration } from './definitions/resolver.js';
|
||||
import type { AgentLoopContext } from '../config/agent-loop-context.js';
|
||||
|
||||
export const OUTPUT_UPDATE_INTERVAL_MS = 1000;
|
||||
|
||||
@@ -63,7 +63,7 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
ToolResult
|
||||
> {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly context: AgentLoopContext,
|
||||
params: ShellToolParams,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
@@ -168,7 +168,7 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
.toString('hex')}.tmp`;
|
||||
const tempFilePath = path.join(os.tmpdir(), tempFileName);
|
||||
|
||||
const timeoutMs = this.config.getShellToolInactivityTimeout();
|
||||
const timeoutMs = this.context.config.getShellToolInactivityTimeout();
|
||||
const timeoutController = new AbortController();
|
||||
let timeoutTimer: NodeJS.Timeout | undefined;
|
||||
|
||||
@@ -189,10 +189,10 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
})();
|
||||
|
||||
const cwd = this.params.dir_path
|
||||
? path.resolve(this.config.getTargetDir(), this.params.dir_path)
|
||||
: this.config.getTargetDir();
|
||||
? path.resolve(this.context.config.getTargetDir(), this.params.dir_path)
|
||||
: this.context.config.getTargetDir();
|
||||
|
||||
const validationError = this.config.validatePathAccess(cwd);
|
||||
const validationError = this.context.config.validatePathAccess(cwd);
|
||||
if (validationError) {
|
||||
return {
|
||||
llmContent: validationError,
|
||||
@@ -271,13 +271,13 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
}
|
||||
},
|
||||
combinedController.signal,
|
||||
this.config.getEnableInteractiveShell(),
|
||||
this.context.config.getEnableInteractiveShell(),
|
||||
{
|
||||
...shellExecutionConfig,
|
||||
pager: 'cat',
|
||||
sanitizationConfig:
|
||||
shellExecutionConfig?.sanitizationConfig ??
|
||||
this.config.sanitizationConfig,
|
||||
this.context.config.sanitizationConfig,
|
||||
},
|
||||
);
|
||||
|
||||
@@ -382,7 +382,7 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
}
|
||||
|
||||
let returnDisplayMessage = '';
|
||||
if (this.config.getDebugMode()) {
|
||||
if (this.context.config.getDebugMode()) {
|
||||
returnDisplayMessage = llmContent;
|
||||
} else {
|
||||
if (this.params.is_background || result.backgrounded) {
|
||||
@@ -411,7 +411,8 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
}
|
||||
}
|
||||
|
||||
const summarizeConfig = this.config.getSummarizeToolOutputConfig();
|
||||
const summarizeConfig =
|
||||
this.context.config.getSummarizeToolOutputConfig();
|
||||
const executionError = result.error
|
||||
? {
|
||||
error: {
|
||||
@@ -422,10 +423,10 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
: {};
|
||||
if (summarizeConfig && summarizeConfig[SHELL_TOOL_NAME]) {
|
||||
const summary = await summarizeToolOutput(
|
||||
this.config,
|
||||
this.context.config,
|
||||
{ model: 'summarizer-shell' },
|
||||
llmContent,
|
||||
this.config.getGeminiClient(),
|
||||
this.context.geminiClient,
|
||||
signal,
|
||||
);
|
||||
return {
|
||||
@@ -461,15 +462,15 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
static readonly Name = SHELL_TOOL_NAME;
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly context: AgentLoopContext,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
void initializeShellParsers().catch(() => {
|
||||
// Errors are surfaced when parsing commands.
|
||||
});
|
||||
const definition = getShellDefinition(
|
||||
config.getEnableInteractiveShell(),
|
||||
config.getEnableShellOutputEfficiency(),
|
||||
context.config.getEnableInteractiveShell(),
|
||||
context.config.getEnableShellOutputEfficiency(),
|
||||
);
|
||||
super(
|
||||
ShellTool.Name,
|
||||
@@ -492,10 +493,10 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
|
||||
if (params.dir_path) {
|
||||
const resolvedPath = path.resolve(
|
||||
this.config.getTargetDir(),
|
||||
this.context.config.getTargetDir(),
|
||||
params.dir_path,
|
||||
);
|
||||
return this.config.validatePathAccess(resolvedPath);
|
||||
return this.context.config.validatePathAccess(resolvedPath);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
@@ -507,7 +508,7 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<ShellToolParams, ToolResult> {
|
||||
return new ShellToolInvocation(
|
||||
this.config,
|
||||
this.context.config,
|
||||
params,
|
||||
messageBus,
|
||||
_toolName,
|
||||
@@ -517,8 +518,8 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
|
||||
override getSchema(modelId?: string) {
|
||||
const definition = getShellDefinition(
|
||||
this.config.getEnableInteractiveShell(),
|
||||
this.config.getEnableShellOutputEfficiency(),
|
||||
this.context.config.getEnableInteractiveShell(),
|
||||
this.context.config.getEnableShellOutputEfficiency(),
|
||||
);
|
||||
return resolveToolDeclaration(definition, modelId);
|
||||
}
|
||||
|
||||
@@ -201,7 +201,7 @@ export class ToolRegistry {
|
||||
// and `isActive` to get only the active tools.
|
||||
private allKnownTools: Map<string, AnyDeclarativeTool> = new Map();
|
||||
private config: Config;
|
||||
private messageBus: MessageBus;
|
||||
readonly messageBus: MessageBus;
|
||||
|
||||
constructor(config: Config, messageBus: MessageBus) {
|
||||
this.config = config;
|
||||
|
||||
@@ -277,6 +277,12 @@ describe('WebFetchTool', () => {
|
||||
setApprovalMode: vi.fn(),
|
||||
getProxy: vi.fn(),
|
||||
getGeminiClient: mockGetGeminiClient,
|
||||
get config() {
|
||||
return this;
|
||||
},
|
||||
get geminiClient() {
|
||||
return mockGetGeminiClient();
|
||||
},
|
||||
getRetryFetchErrors: vi.fn().mockReturnValue(false),
|
||||
getMaxAttempts: vi.fn().mockReturnValue(3),
|
||||
getDirectWebFetch: vi.fn().mockReturnValue(false),
|
||||
|
||||
@@ -18,7 +18,6 @@ import { buildParamArgsPattern } from '../policy/utils.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
import { getResponseText } from '../utils/partUtils.js';
|
||||
import { fetchWithTimeout, isPrivateIp } from '../utils/fetch.js';
|
||||
@@ -38,6 +37,7 @@ import { retryWithBackoff, getRetryErrorType } from '../utils/retry.js';
|
||||
import { WEB_FETCH_DEFINITION } from './definitions/coreTools.js';
|
||||
import { resolveToolDeclaration } from './definitions/resolver.js';
|
||||
import { LRUCache } from 'mnemonist';
|
||||
import type { AgentLoopContext } from '../config/agent-loop-context.js';
|
||||
|
||||
const URL_FETCH_TIMEOUT_MS = 10000;
|
||||
const MAX_CONTENT_LENGTH = 100000;
|
||||
@@ -213,7 +213,7 @@ class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
ToolResult
|
||||
> {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly context: AgentLoopContext,
|
||||
params: WebFetchToolParams,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
@@ -223,7 +223,7 @@ class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
}
|
||||
|
||||
private handleRetry(attempt: number, error: unknown, delayMs: number): void {
|
||||
const maxAttempts = this.config.getMaxAttempts();
|
||||
const maxAttempts = this.context.config.getMaxAttempts();
|
||||
const modelName = 'Web Fetch';
|
||||
const errorType = getRetryErrorType(error);
|
||||
|
||||
@@ -236,7 +236,7 @@ class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
});
|
||||
|
||||
logNetworkRetryAttempt(
|
||||
this.config,
|
||||
this.context.config,
|
||||
new NetworkRetryAttemptEvent(
|
||||
attempt,
|
||||
maxAttempts,
|
||||
@@ -290,7 +290,7 @@ class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
return res;
|
||||
},
|
||||
{
|
||||
retryFetchErrors: this.config.getRetryFetchErrors(),
|
||||
retryFetchErrors: this.context.config.getRetryFetchErrors(),
|
||||
onRetry: (attempt, error, delayMs) =>
|
||||
this.handleRetry(attempt, error, delayMs),
|
||||
signal,
|
||||
@@ -342,7 +342,7 @@ class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
`[WebFetchTool] Skipped private or local host: ${url}`,
|
||||
);
|
||||
logWebFetchFallbackAttempt(
|
||||
this.config,
|
||||
this.context.config,
|
||||
new WebFetchFallbackAttemptEvent('private_ip_skipped'),
|
||||
);
|
||||
skipped.push(`[Blocked Host] ${url}`);
|
||||
@@ -379,7 +379,7 @@ class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
.join('\n\n---\n\n');
|
||||
|
||||
try {
|
||||
const geminiClient = this.config.getGeminiClient();
|
||||
const geminiClient = this.context.geminiClient;
|
||||
const fallbackPrompt = `The user requested the following: "${this.params.prompt}".
|
||||
|
||||
I was unable to access the URL(s) directly using the primary fetch tool. Instead, I have fetched the raw content of the page(s). Please use the following content to answer the request. Do not attempt to access the URL(s) again.
|
||||
@@ -458,7 +458,7 @@ ${aggregatedContent}
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
// Check for AUTO_EDIT approval mode. This tool has a specific behavior
|
||||
// where ProceedAlways switches the entire session to AUTO_EDIT.
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
if (this.context.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -581,7 +581,7 @@ ${aggregatedContent}
|
||||
return res;
|
||||
},
|
||||
{
|
||||
retryFetchErrors: this.config.getRetryFetchErrors(),
|
||||
retryFetchErrors: this.context.config.getRetryFetchErrors(),
|
||||
onRetry: (attempt, error, delayMs) =>
|
||||
this.handleRetry(attempt, error, delayMs),
|
||||
signal,
|
||||
@@ -692,7 +692,7 @@ Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response trun
|
||||
}
|
||||
|
||||
async execute(signal: AbortSignal): Promise<ToolResult> {
|
||||
if (this.config.getDirectWebFetch()) {
|
||||
if (this.context.config.getDirectWebFetch()) {
|
||||
return this.executeExperimental(signal);
|
||||
}
|
||||
const userPrompt = this.params.prompt!;
|
||||
@@ -715,7 +715,7 @@ Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response trun
|
||||
}
|
||||
|
||||
try {
|
||||
const geminiClient = this.config.getGeminiClient();
|
||||
const geminiClient = this.context.geminiClient;
|
||||
const response = await geminiClient.generateContent(
|
||||
{ model: 'web-fetch' },
|
||||
[{ role: 'user', parts: [{ text: userPrompt }] }],
|
||||
@@ -797,7 +797,7 @@ Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response trun
|
||||
`[WebFetchTool] Primary fetch failed, falling back: ${getErrorMessage(error)}`,
|
||||
);
|
||||
logWebFetchFallbackAttempt(
|
||||
this.config,
|
||||
this.context.config,
|
||||
new WebFetchFallbackAttemptEvent('primary_failed'),
|
||||
);
|
||||
// Simple All-or-Nothing Fallback
|
||||
@@ -816,7 +816,7 @@ export class WebFetchTool extends BaseDeclarativeTool<
|
||||
static readonly Name = WEB_FETCH_TOOL_NAME;
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly context: AgentLoopContext,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
super(
|
||||
@@ -834,7 +834,7 @@ export class WebFetchTool extends BaseDeclarativeTool<
|
||||
protected override validateToolParamValues(
|
||||
params: WebFetchToolParams,
|
||||
): string | null {
|
||||
if (this.config.getDirectWebFetch()) {
|
||||
if (this.context.config.getDirectWebFetch()) {
|
||||
if (!params.url) {
|
||||
return "The 'url' parameter is required.";
|
||||
}
|
||||
@@ -870,7 +870,7 @@ export class WebFetchTool extends BaseDeclarativeTool<
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<WebFetchToolParams, ToolResult> {
|
||||
return new WebFetchToolInvocation(
|
||||
this.config,
|
||||
this.context.config,
|
||||
params,
|
||||
messageBus,
|
||||
_toolName,
|
||||
@@ -880,7 +880,7 @@ export class WebFetchTool extends BaseDeclarativeTool<
|
||||
|
||||
override getSchema(modelId?: string) {
|
||||
const schema = resolveToolDeclaration(WEB_FETCH_DEFINITION, modelId);
|
||||
if (this.config.getDirectWebFetch()) {
|
||||
if (this.context.config.getDirectWebFetch()) {
|
||||
return {
|
||||
...schema,
|
||||
description:
|
||||
|
||||
@@ -31,6 +31,9 @@ describe('WebSearchTool', () => {
|
||||
beforeEach(() => {
|
||||
const mockConfigInstance = {
|
||||
getGeminiClient: () => mockGeminiClient,
|
||||
get geminiClient() {
|
||||
return mockGeminiClient;
|
||||
},
|
||||
getProxy: () => undefined,
|
||||
generationConfigService: {
|
||||
getResolvedConfig: vi.fn().mockImplementation(({ model }) => ({
|
||||
|
||||
@@ -17,12 +17,12 @@ import {
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
|
||||
import { getErrorMessage, isAbortError } from '../utils/errors.js';
|
||||
import { type Config } from '../config/config.js';
|
||||
import { getResponseText } from '../utils/partUtils.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { WEB_SEARCH_DEFINITION } from './definitions/coreTools.js';
|
||||
import { resolveToolDeclaration } from './definitions/resolver.js';
|
||||
import { LlmRole } from '../telemetry/llmRole.js';
|
||||
import type { AgentLoopContext } from '../config/agent-loop-context.js';
|
||||
|
||||
interface GroundingChunkWeb {
|
||||
uri?: string;
|
||||
@@ -71,7 +71,7 @@ class WebSearchToolInvocation extends BaseToolInvocation<
|
||||
WebSearchToolResult
|
||||
> {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly context: AgentLoopContext,
|
||||
params: WebSearchToolParams,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
@@ -85,7 +85,7 @@ class WebSearchToolInvocation extends BaseToolInvocation<
|
||||
}
|
||||
|
||||
async execute(signal: AbortSignal): Promise<WebSearchToolResult> {
|
||||
const geminiClient = this.config.getGeminiClient();
|
||||
const geminiClient = this.context.geminiClient;
|
||||
|
||||
try {
|
||||
const response = await geminiClient.generateContent(
|
||||
@@ -207,7 +207,7 @@ export class WebSearchTool extends BaseDeclarativeTool<
|
||||
static readonly Name = WEB_SEARCH_TOOL_NAME;
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly context: AgentLoopContext,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
super(
|
||||
@@ -243,7 +243,7 @@ export class WebSearchTool extends BaseDeclarativeTool<
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<WebSearchToolParams, WebSearchToolResult> {
|
||||
return new WebSearchToolInvocation(
|
||||
this.config,
|
||||
this.context.config,
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
_toolName,
|
||||
|
||||
Reference in New Issue
Block a user