feat(core): Support auto-distillation for tool output.

This commit is contained in:
Your Name
2026-03-11 01:18:41 +00:00
parent 39d3b0e28c
commit d6173843b0
23 changed files with 1104 additions and 238 deletions
+16 -7
View File
@@ -537,6 +537,7 @@ export interface ConfigParameters {
mcpEnablementCallbacks?: McpEnablementCallbacks;
userMemory?: string | HierarchicalMemory;
geminiMdFileCount?: number;
contentGenerator?: ContentGenerator;
geminiMdFilePaths?: string[];
approvalMode?: ApprovalMode;
showMemoryUsage?: boolean;
@@ -608,7 +609,7 @@ export interface ConfigParameters {
maxAttempts?: number;
enableShellOutputEfficiency?: boolean;
shellToolInactivityTimeout?: number;
fakeResponses?: string;
fakeResponses?: string | any[];
recordResponses?: string;
ptyInfo?: string;
disableYoloMode?: boolean;
@@ -672,6 +673,7 @@ export class Config implements McpContext, AgentLoopContext {
private trackerService?: TrackerService;
private contentGeneratorConfig!: ContentGeneratorConfig;
private contentGenerator!: ContentGenerator;
private _initialContentGenerator?: ContentGenerator;
readonly modelConfigService: ModelConfigService;
private readonly embeddingModel: string;
private readonly sandbox: SandboxConfig | undefined;
@@ -813,7 +815,7 @@ export class Config implements McpContext, AgentLoopContext {
private readonly maxAttempts: number;
private readonly enableShellOutputEfficiency: boolean;
private readonly shellToolInactivityTimeout: number;
readonly fakeResponses?: string;
readonly fakeResponses?: string | any[];
readonly recordResponses?: string;
private readonly disableYoloMode: boolean;
private readonly disableAlwaysAllow: boolean;
@@ -894,6 +896,7 @@ export class Config implements McpContext, AgentLoopContext {
this.pendingIncludeDirectories = params.includeDirectories ?? [];
this.debugMode = params.debugMode;
this.question = params.question;
this._initialContentGenerator = params.contentGenerator;
this.coreTools = params.coreTools;
this.mainAgentTools = params.mainAgentTools;
@@ -1356,11 +1359,17 @@ export class Config implements McpContext, AgentLoopContext {
baseUrl,
customHeaders,
);
this.contentGenerator = await createContentGenerator(
newContentGeneratorConfig,
this,
this.getSessionId(),
);
if (this._initialContentGenerator) {
this.contentGenerator = this._initialContentGenerator;
// We only use it once, on first initialization. Future refreshes will create real ones
// unless we want it to persist forever, but usually AppRig manages this.
} else {
this.contentGenerator = await createContentGenerator(
newContentGeneratorConfig,
this,
this.getSessionId(),
);
}
// Only assign to instance properties after successful initialization
this.contentGeneratorConfig = newContentGeneratorConfig;
+33 -21
View File
@@ -21,10 +21,11 @@ import type { UserTierId, GeminiUserTier } from '../code_assist/types.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
import { InstallationManager } from '../utils/installationManager.js';
import { FakeContentGenerator } from './fakeContentGenerator.js';
import { FallbackContentGenerator } from './fallbackContentGenerator.js';
import { parseCustomHeaders } from '../utils/customHeaderUtils.js';
import { determineSurface } from '../utils/surface.js';
import { RecordingContentGenerator } from './recordingContentGenerator.js';
import { getVersion, resolveModel } from '../../index.js';
import { debugLogger, getVersion, resolveModel } from '../../index.js';
import type { LlmRole } from '../telemetry/llmRole.js';
/**
@@ -47,6 +48,8 @@ export interface ContentGenerator {
embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse>;
getSentRequests?(): GenerateContentParameters[];
userTier?: UserTierId;
userTierName?: string;
@@ -166,12 +169,6 @@ export async function createContentGenerator(
sessionId?: string,
): Promise<ContentGenerator> {
const generator = await (async () => {
if (gcConfig.fakeResponses) {
const fakeGenerator = await FakeContentGenerator.fromFile(
gcConfig.fakeResponses,
);
return new LoggingContentGenerator(fakeGenerator, gcConfig);
}
const version = await getVersion();
const model = resolveModel(
gcConfig.getModel(),
@@ -208,23 +205,21 @@ export async function createContentGenerator(
) {
baseHeaders['Authorization'] = `Bearer ${config.apiKey}`;
}
let realGenerator: ContentGenerator;
if (
config.authType === AuthType.LOGIN_WITH_GOOGLE ||
config.authType === AuthType.COMPUTE_ADC
) {
const httpOptions = { headers: baseHeaders };
return new LoggingContentGenerator(
await createCodeAssistContentGenerator(
httpOptions,
config.authType,
gcConfig,
sessionId,
),
realGenerator = await createCodeAssistContentGenerator(
httpOptions,
config.authType,
gcConfig,
sessionId,
);
}
if (
} else if (
config.authType === AuthType.USE_GEMINI ||
config.authType === AuthType.USE_VERTEX_AI ||
config.authType === AuthType.GATEWAY
@@ -268,11 +263,28 @@ export async function createContentGenerator(
httpOptions,
...(apiVersionEnv && { apiVersion: apiVersionEnv }),
});
return new LoggingContentGenerator(googleGenAI.models, gcConfig);
realGenerator = googleGenAI.models;
} else {
throw new Error(
`Error creating contentGenerator: Unsupported authType: ${config.authType}`,
);
}
throw new Error(
`Error creating contentGenerator: Unsupported authType: ${config.authType}`,
);
let targetGenerator = realGenerator;
if (gcConfig.fakeResponses) {
if (Array.isArray(gcConfig.fakeResponses)) {
debugLogger.log(`[createContentGenerator] Instantiating FakeContentGenerator with ${gcConfig.fakeResponses.length} in-memory mock responses.`);
const fakeGen = new FakeContentGenerator(gcConfig.fakeResponses);
targetGenerator = new FallbackContentGenerator(fakeGen, realGenerator);
} else {
debugLogger.log(`[createContentGenerator] Instantiating FakeContentGenerator from file: ${gcConfig.fakeResponses}`);
const fakeGen = await FakeContentGenerator.fromFile(gcConfig.fakeResponses);
targetGenerator = new FallbackContentGenerator(fakeGen, realGenerator);
}
}
return new LoggingContentGenerator(targetGenerator, gcConfig);
})();
if (gcConfig.recordResponses) {
+40 -20
View File
@@ -18,6 +18,16 @@ import type { UserTierId, GeminiUserTier } from '../code_assist/types.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
import type { LlmRole } from '../telemetry/types.js';
export class MockExhaustedError extends Error {
constructor(method: string, request?: unknown) {
super(
`No more mock responses for ${method}, got request:\n` +
safeJsonStringify(request),
);
this.name = 'MockExhaustedError';
}
}
export type FakeResponse =
| {
method: 'generateContent';
@@ -42,13 +52,20 @@ export type FakeResponse =
// CLI argument.
export class FakeContentGenerator implements ContentGenerator {
private callCounter = 0;
private sentRequests: GenerateContentParameters[] = [];
userTier?: UserTierId;
userTierName?: string;
paidTier?: GeminiUserTier;
constructor(private readonly responses: FakeResponse[]) {}
static async fromFile(filePath: string): Promise<FakeContentGenerator> {
getSentRequests(): GenerateContentParameters[] {
return this.sentRequests;
}
static async fromFile(
filePath: string,
): Promise<FakeContentGenerator> {
const fileContent = await promises.readFile(filePath, 'utf-8');
const responses = fileContent
.split('\n')
@@ -62,13 +79,14 @@ export class FakeContentGenerator implements ContentGenerator {
M extends FakeResponse['method'],
R = Extract<FakeResponse, { method: M }>['response'],
>(method: M, request: unknown): R {
const response = this.responses[this.callCounter++];
const response = this.responses[this.callCounter];
if (!response) {
throw new Error(
`No more mock responses for ${method}, got request:\n` +
safeJsonStringify(request),
);
throw new MockExhaustedError(method, request);
}
// We only increment the counter if we actually consume a mock response
this.callCounter++;
if (response.method !== method) {
throw new Error(
`Unexpected response type, next response was for ${response.method} but expected ${method}`,
@@ -80,26 +98,29 @@ export class FakeContentGenerator implements ContentGenerator {
async generateContent(
request: GenerateContentParameters,
_userPromptId: string,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
userPromptId: string,
role: LlmRole,
): Promise<GenerateContentResponse> {
this.sentRequests.push(request);
const next = this.getNextResponse('generateContent', request);
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
return Object.setPrototypeOf(
this.getNextResponse('generateContent', request),
GenerateContentResponse.prototype,
);
return Object.setPrototypeOf(next, GenerateContentResponse.prototype);
}
async generateContentStream(
request: GenerateContentParameters,
_userPromptId: string,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
userPromptId: string,
role: LlmRole,
): Promise<AsyncGenerator<GenerateContentResponse>> {
this.sentRequests.push(request);
const responses = this.getNextResponse('generateContentStream', request);
async function* stream() {
for (const response of responses) {
// Add a tiny delay to ensure React has time to render the 'Responding'
// state. If the mock stream finishes synchronously, AppRig's
// awaitingResponse flag may never be cleared, causing the rig to hang.
await new Promise((resolve) => setTimeout(resolve, 5));
for (const response of responses!) {
yield Object.setPrototypeOf(
response,
GenerateContentResponse.prototype,
@@ -112,16 +133,15 @@ export class FakeContentGenerator implements ContentGenerator {
async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
return this.getNextResponse('countTokens', request);
const next = this.getNextResponse('countTokens', request);
return next;
}
async embedContent(
request: EmbedContentParameters,
): Promise<EmbedContentResponse> {
const next = this.getNextResponse('embedContent', request);
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
return Object.setPrototypeOf(
this.getNextResponse('embedContent', request),
EmbedContentResponse.prototype,
);
return Object.setPrototypeOf(next, EmbedContentResponse.prototype);
}
}
@@ -0,0 +1,97 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { ContentGenerator } from './contentGenerator.js';
import type { GenerateContentParameters, GenerateContentResponse, CountTokensParameters, CountTokensResponse, EmbedContentParameters, EmbedContentResponse } from '@google/genai';
import type { LlmRole } from '../telemetry/types.js';
import { debugLogger } from '../utils/debugLogger.js';
import { MockExhaustedError } from './fakeContentGenerator.js';
/**
* A ContentGenerator that attempts to use a primary generator,
* and falls back to a secondary generator if the primary throws MockExhaustedError.
*/
export class FallbackContentGenerator implements ContentGenerator {
get userTier() { return this.primary.userTier; }
get userTierName() { return this.primary.userTierName; }
get paidTier() { return this.primary.paidTier; }
constructor(
private readonly primary: ContentGenerator,
private readonly fallback: ContentGenerator,
private readonly onFallback?: (method: string) => void,
) {}
async generateContent(
request: GenerateContentParameters,
userPromptId: string,
role: LlmRole,
): Promise<GenerateContentResponse> {
try {
return await this.primary.generateContent(request, userPromptId, role);
} catch (error) {
if (error instanceof MockExhaustedError) {
debugLogger.log(`[FallbackContentGenerator] Exhausted primary generator for generateContent. Falling back.`);
this.onFallback?.('generateContent');
return this.fallback.generateContent(request, userPromptId, role);
}
throw error;
}
}
async generateContentStream(
request: GenerateContentParameters,
userPromptId: string,
role: LlmRole,
): Promise<AsyncGenerator<GenerateContentResponse>> {
try {
return await this.primary.generateContentStream(request, userPromptId, role);
} catch (error) {
if (error instanceof MockExhaustedError) {
debugLogger.log(`[FallbackContentGenerator] Exhausted primary generator for generateContentStream. Falling back.`);
this.onFallback?.('generateContentStream');
return this.fallback.generateContentStream(request, userPromptId, role);
}
throw error;
}
}
async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
try {
if (!this.primary.countTokens) {
throw new MockExhaustedError('countTokens');
}
return await this.primary.countTokens(request);
} catch (error) {
if (error instanceof MockExhaustedError && this.fallback.countTokens) {
debugLogger.log(`[FallbackContentGenerator] Exhausted primary generator for countTokens. Falling back.`);
this.onFallback?.('countTokens');
return this.fallback.countTokens(request);
}
throw error;
}
}
async embedContent(
request: EmbedContentParameters,
): Promise<EmbedContentResponse> {
try {
if (!this.primary.embedContent) {
throw new MockExhaustedError('embedContent');
}
return await this.primary.embedContent(request);
} catch (error) {
if (error instanceof MockExhaustedError && this.fallback.embedContent) {
debugLogger.log(`[FallbackContentGenerator] Exhausted primary generator for embedContent. Falling back.`);
this.onFallback?.('embedContent');
return this.fallback.embedContent(request);
}
throw error;
}
}
}
@@ -168,6 +168,10 @@ export class LoggingContentGenerator implements ContentGenerator {
return this.wrapped.paidTier;
}
getSentRequests?(): GenerateContentParameters[] {
return this.wrapped.getSentRequests?.() || [];
}
private logApiRequest(
contents: Content[],
model: string,
@@ -39,6 +39,10 @@ export class RecordingContentGenerator implements ContentGenerator {
return this.realGenerator.userTierName;
}
getSentRequests?(): GenerateContentParameters[] {
return this.realGenerator.getSentRequests?.() || [];
}
async generateContent(
request: GenerateContentParameters,
userPromptId: string,
+57
View File
@@ -0,0 +1,57 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { GenerateContentResponse } from '@google/genai';
import type { FakeResponse } from './fakeContentGenerator.js';
export type FakeRequest = { method: 'userText'; text: string };
export type ScriptItem = FakeResponse | FakeRequest;
export function mockGenerateContentStreamText(text: string): FakeResponse {
return {
method: 'generateContentStream',
response: [
{
candidates: [{ content: { parts: [{ text }] }, finishReason: 'STOP' }],
},
] as GenerateContentResponse[],
};
}
export function mockGenerateContentText(text: string): FakeResponse {
return {
method: 'generateContent',
response: {
candidates: [{ content: { parts: [{ text }] }, finishReason: 'STOP' }],
} as GenerateContentResponse,
};
}
export function userText(text: string): FakeRequest {
return { method: 'userText', text };
}
export function isFakeResponse(item: ScriptItem): item is FakeResponse {
return item.method !== 'userText';
}
export function isFakeRequest(item: ScriptItem): item is FakeRequest {
return item.method === 'userText';
}
/**
* Extracts all FakeRequests from a script array and maps them to their string text.
*/
export function extractUserPrompts(script: ScriptItem[]): string[] {
return script.filter(isFakeRequest).map((req) => req.text);
}
/**
* Extracts all FakeResponses from a script array.
*/
export function extractFakeResponses(script: ScriptItem[]): FakeResponse[] {
return script.filter(isFakeResponse);
}
+4
View File
@@ -24,6 +24,7 @@ export * from './config/extensions/integrityTypes.js';
export * from './billing/index.js';
export * from './confirmation-bus/types.js';
export * from './confirmation-bus/message-bus.js';
export * from './safety/conseca/conseca.js';
// Export Commands logic
export * from './commands/extensions.js';
@@ -36,6 +37,9 @@ export * from './commands/types.js';
export * from './core/baseLlmClient.js';
export * from './core/client.js';
export * from './core/contentGenerator.js';
export * from './core/fakeContentGenerator.js';
export * from './core/fallbackContentGenerator.js';
export * from './core/scriptUtils.js';
export * from './core/loggingContentGenerator.js';
export * from './core/geminiChat.js';
export * from './core/logger.js';
@@ -335,8 +335,10 @@ describe('ToolExecutor', () => {
it('should truncate large shell output', async () => {
// 1. Setup Config for Truncation
vi.spyOn(config, 'getTruncateToolOutputThreshold').mockReturnValue(10);
vi.spyOn(config.storage, 'getProjectTempDir').mockReturnValue('/tmp');
vi.spyOn(config.storage, 'getProjectTempDir').mockReturnValue('/tmp');
const mockTool = new MockTool({ name: SHELL_TOOL_NAME });
const invocation = mockTool.build({});
@@ -396,8 +398,10 @@ describe('ToolExecutor', () => {
it('should truncate large MCP tool output with single text Part', async () => {
// 1. Setup Config for Truncation
vi.spyOn(config, 'getTruncateToolOutputThreshold').mockReturnValue(10);
vi.spyOn(config.storage, 'getProjectTempDir').mockReturnValue('/tmp');
vi.spyOn(config.storage, 'getProjectTempDir').mockReturnValue('/tmp');
const mcpToolName = 'get_big_text';
const messageBus = createMockMessageBus();
@@ -440,8 +444,9 @@ describe('ToolExecutor', () => {
});
// 4. Verify Truncation Logic
const stringifiedLongText = JSON.stringify([{ text: longText }], null, 2);
expect(fileUtils.saveTruncatedToolOutput).toHaveBeenCalledWith(
longText,
stringifiedLongText,
mcpToolName,
'call-mcp-trunc',
expect.any(String),
@@ -449,7 +454,7 @@ describe('ToolExecutor', () => {
);
expect(fileUtils.formatTruncatedToolOutput).toHaveBeenCalledWith(
longText,
stringifiedLongText,
'/tmp/truncated_output.txt',
10,
);
@@ -460,8 +465,9 @@ describe('ToolExecutor', () => {
}
});
it('should not truncate MCP tool output with multiple Parts', async () => {
it('should truncate MCP tool output with multiple Parts', async () => {
vi.spyOn(config, 'getTruncateToolOutputThreshold').mockReturnValue(10);
vi.spyOn(config.storage, 'getProjectTempDir').mockReturnValue('/tmp');
const messageBus = createMockMessageBus();
const mcpTool = new DiscoveredMCPTool(
@@ -501,9 +507,26 @@ describe('ToolExecutor', () => {
onUpdateToolCall: vi.fn(),
});
// Should NOT have been truncated
expect(fileUtils.saveTruncatedToolOutput).not.toHaveBeenCalled();
expect(fileUtils.formatTruncatedToolOutput).not.toHaveBeenCalled();
const longText1 = 'This is long text that exceeds the threshold.';
const stringifiedLongText = JSON.stringify(
[{ text: longText1 }, { text: 'second part' }],
null,
2,
);
// Should HAVE been truncated now
expect(fileUtils.saveTruncatedToolOutput).toHaveBeenCalledWith(
stringifiedLongText,
'get_big_text',
'call-mcp-multi',
expect.any(String),
'test-session-id',
);
expect(fileUtils.formatTruncatedToolOutput).toHaveBeenCalledWith(
stringifiedLongText,
'/tmp/truncated_output.txt',
10,
);
expect(result.status).toBe(CoreToolCallStatus.Success);
});
@@ -712,8 +735,10 @@ describe('ToolExecutor', () => {
it('should truncate large shell output even on cancellation', async () => {
// 1. Setup Config for Truncation
vi.spyOn(config, 'getTruncateToolOutputThreshold').mockReturnValue(10);
vi.spyOn(config.storage, 'getProjectTempDir').mockReturnValue('/tmp');
vi.spyOn(config.storage, 'getProjectTempDir').mockReturnValue('/tmp');
const mockTool = new MockTool({ name: SHELL_TOOL_NAME });
const invocation = mockTool.build({});
+9 -90
View File
@@ -6,8 +6,6 @@
import {
ToolErrorType,
ToolOutputTruncatedEvent,
logToolOutputTruncated,
runInDevTraceSpan,
type ToolCallRequestInfo,
type ToolCallResponseInfo,
@@ -19,11 +17,10 @@ import {
import { isAbortError } from '../utils/errors.js';
import { SHELL_TOOL_NAME } from '../tools/tool-names.js';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import { ToolOutputDistillationService } from '../services/toolDistillationService.js';
import { ShellToolInvocation } from '../tools/shell.js';
import { executeToolWithHooks } from '../core/coreToolHookTriggers.js';
import {
saveTruncatedToolOutput,
formatTruncatedToolOutput,
} from '../utils/fileUtils.js';
import { convertToFunctionResponse } from '../utils/generateContentResponseUtilities.js';
import {
CoreToolCallStatus,
@@ -180,90 +177,12 @@ export class ToolExecutor {
call: ToolCall,
content: PartListUnion,
): Promise<{ truncatedContent: PartListUnion; outputFile?: string }> {
const toolName = call.request.name;
const callId = call.request.callId;
let outputFile: string | undefined;
if (typeof content === 'string' && toolName === SHELL_TOOL_NAME) {
const threshold = this.config.getTruncateToolOutputThreshold();
if (threshold > 0 && content.length > threshold) {
const originalContentLength = content.length;
const { outputFile: savedPath } = await saveTruncatedToolOutput(
content,
toolName,
callId,
this.config.storage.getProjectTempDir(),
this.context.promptId,
);
outputFile = savedPath;
const truncatedContent = formatTruncatedToolOutput(
content,
outputFile,
threshold,
);
logToolOutputTruncated(
this.config,
new ToolOutputTruncatedEvent(call.request.prompt_id, {
toolName,
originalContentLength,
truncatedContentLength: truncatedContent.length,
threshold,
}),
);
return { truncatedContent, outputFile };
}
} else if (
Array.isArray(content) &&
content.length === 1 &&
'tool' in call &&
call.tool instanceof DiscoveredMCPTool
) {
const firstPart = content[0];
if (typeof firstPart === 'object' && typeof firstPart.text === 'string') {
const textContent = firstPart.text;
const threshold = this.config.getTruncateToolOutputThreshold();
if (threshold > 0 && textContent.length > threshold) {
const originalContentLength = textContent.length;
const { outputFile: savedPath } = await saveTruncatedToolOutput(
textContent,
toolName,
callId,
this.config.storage.getProjectTempDir(),
this.context.promptId,
);
outputFile = savedPath;
const truncatedText = formatTruncatedToolOutput(
textContent,
outputFile,
threshold,
);
// We need to return a NEW array to avoid mutating the original toolResult if it matters,
// though here we are creating the response so it's probably fine to mutate or return new.
const truncatedContent: Part[] = [
{ ...firstPart, text: truncatedText },
];
logToolOutputTruncated(
this.config,
new ToolOutputTruncatedEvent(call.request.prompt_id, {
toolName,
originalContentLength,
truncatedContentLength: truncatedText.length,
threshold,
}),
);
return { truncatedContent, outputFile };
}
}
}
return { truncatedContent: content, outputFile };
const distiller = new ToolOutputDistillationService(
this.config,
this.context.geminiClient,
this.context.promptId,
);
return distiller.distill(call.request.name, call.request.callId, content);
}
private async createCancelledResult(
@@ -0,0 +1,203 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
LlmRole,
ToolOutputTruncatedEvent,
logToolOutputTruncated,
debugLogger,
type Config,
} from '../index.js';
import type { PartListUnion } from '@google/genai';
import { type GeminiClient } from '../core/client.js';
import { DEFAULT_GEMINI_FLASH_LITE_MODEL } from '../config/models.js';
import {
saveTruncatedToolOutput,
formatTruncatedToolOutput,
} from '../utils/fileUtils.js';
import {
READ_FILE_TOOL_NAME,
READ_MANY_FILES_TOOL_NAME,
} from '../tools/tool-names.js';
export interface DistilledToolOutput {
truncatedContent: PartListUnion;
outputFile?: string;
}
export class ToolOutputDistillationService {
constructor(
private readonly config: Config,
private readonly geminiClient: GeminiClient,
private readonly promptId: string,
) {}
/**
* Distills a tool's output if it exceeds configured length thresholds, preserving
* the agent's context window. This includes saving the raw output to disk, replacing
* the output with a truncated placeholder, and optionally summarizing the output
* via a secondary LLM call if the output is massively oversized.
*/
async distill(
toolName: string,
callId: string,
content: PartListUnion,
): Promise<DistilledToolOutput> {
// Explicitly bypass escape hatches that natively handle large outputs
if (this.isExemptFromDistillation(toolName)) {
return { truncatedContent: content };
}
const threshold = this.config.getTruncateToolOutputThreshold();
if (threshold <= 0) {
return { truncatedContent: content };
}
const originalContentLength = this.calculateContentLength(content);
if (originalContentLength > threshold) {
return this.performDistillation(
toolName,
callId,
content,
originalContentLength,
threshold,
);
}
return { truncatedContent: content };
}
private isExemptFromDistillation(toolName: string): boolean {
return (
toolName === READ_FILE_TOOL_NAME || toolName === READ_MANY_FILES_TOOL_NAME
);
}
private calculateContentLength(content: PartListUnion): number {
if (typeof content === 'string') {
return content.length;
}
if (Array.isArray(content)) {
return content.reduce((acc, part) => {
if (
typeof part === 'object' &&
part !== null &&
'text' in part &&
typeof part.text === 'string'
) {
return acc + part.text.length;
}
return acc;
}, 0);
}
return 0;
}
private stringifyContent(content: PartListUnion): string {
return typeof content === 'string'
? content
: JSON.stringify(content, null, 2);
}
private async performDistillation(
toolName: string,
callId: string,
content: PartListUnion,
originalContentLength: number,
threshold: number,
): Promise<DistilledToolOutput> {
const stringifiedContent = this.stringifyContent(content);
// Save the raw, untruncated string to disk for human review
const { outputFile: savedPath } = await saveTruncatedToolOutput(
stringifiedContent,
toolName,
callId,
this.config.storage.getProjectTempDir(),
this.promptId,
);
let truncatedText = formatTruncatedToolOutput(
stringifiedContent,
savedPath,
threshold,
);
// If the output is massively oversized, attempt to generate a structural map
const summarizationThreshold = threshold * 1.5;
if (originalContentLength > summarizationThreshold) {
const summaryText = await this.generateStructuralMap(
toolName,
stringifiedContent,
Math.floor(summarizationThreshold),
);
if (summaryText) {
truncatedText += `\n\n--- Structural Map of Truncated Content ---\n${summaryText}`;
}
}
logToolOutputTruncated(
this.config,
new ToolOutputTruncatedEvent(this.promptId, {
toolName,
originalContentLength,
truncatedContentLength: truncatedText.length,
threshold,
}),
);
return {
truncatedContent:
typeof content === 'string' ? truncatedText : [{ text: truncatedText }],
outputFile: savedPath,
};
}
/**
* Calls a fast, internal model (Flash-Lite) to provide a high-level summary
* of the truncated content's structure.
*/
private async generateStructuralMap(
toolName: string,
stringifiedContent: string,
maxPreviewLen: number,
): Promise<string | undefined> {
try {
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), 15000); // 15s timeout
const promptText = `The following output from the tool '${toolName}' is extremely large and has been truncated. Please provide a very brief, high-level structural map of its contents (e.g., key sections, JSON schema outline, or line number ranges for major components). Keep the summary under 10 lines. Do not attempt to summarize the specific data values, just the structure so another agent knows what is inside.
Output to summarize:
${stringifiedContent.slice(0, maxPreviewLen)}...`;
const summaryResponse = await this.geminiClient.generateContent(
{
model: DEFAULT_GEMINI_FLASH_LITE_MODEL,
overrideScope: 'internal-summarizer',
},
[{ parts: [{ text: promptText }] }],
controller.signal,
LlmRole.MAIN,
);
clearTimeout(timeoutId);
return summaryResponse.candidates?.[0]?.content?.parts?.[0]?.text;
} catch (e) {
// Fail gracefully, summarization is a progressive enhancement
debugLogger.debug(
'Failed to generate structural map for truncated output:',
e,
);
return undefined;
}
}
}
+13 -47
View File
@@ -21,7 +21,6 @@ import { getErrorMessage } from '../utils/errors.js';
import { ApprovalMode } from '../policy/types.js';
import { getResponseText } from '../utils/partUtils.js';
import { fetchWithTimeout, isPrivateIp } from '../utils/fetch.js';
import { truncateString } from '../utils/textUtils.js';
import { convert } from 'html-to-text';
import {
logWebFetchFallbackAttempt,
@@ -40,11 +39,10 @@ import { LRUCache } from 'mnemonist';
import type { AgentLoopContext } from '../config/agent-loop-context.js';
const URL_FETCH_TIMEOUT_MS = 10000;
const MAX_CONTENT_LENGTH = 250000;
const MAX_EXPERIMENTAL_FETCH_SIZE = 10 * 1024 * 1024; // 10MB
const USER_AGENT =
'Mozilla/5.0 (compatible; Google-Gemini-CLI/1.0; +https://github.com/google-gemini/gemini-cli)';
const TRUNCATION_WARNING = '\n\n... [Content truncated due to size limit] ...';
// Rate limiting configuration
const RATE_LIMIT_WINDOW_MS = 60000; // 1 minute
@@ -331,9 +329,7 @@ class WebFetchToolInvocation extends BaseToolInvocation<
textContent = rawContent;
}
// Cap at MAX_CONTENT_LENGTH initially to avoid excessive memory usage
// before the global budget allocation.
return truncateString(textContent, MAX_CONTENT_LENGTH, '');
return textContent;
}
private filterAndValidateUrls(urls: string[]): {
@@ -399,28 +395,10 @@ class WebFetchToolInvocation extends BaseToolInvocation<
};
}
// Smart Budget Allocation (Water-filling algorithm) for successes
const sortedSuccesses = [...successes].sort(
(a, b) => a.content.length - b.content.length,
);
let remainingBudget = MAX_CONTENT_LENGTH;
let remainingUrls = sortedSuccesses.length;
const finalContentsByUrl = new Map<string, string>();
for (const success of sortedSuccesses) {
const fairShare = Math.floor(remainingBudget / remainingUrls);
const allocated = Math.min(success.content.length, fairShare);
const truncated = truncateString(
success.content,
allocated,
TRUNCATION_WARNING,
);
finalContentsByUrl.set(success.url, truncated);
remainingBudget -= truncated.length;
remainingUrls--;
for (const success of successes) {
finalContentsByUrl.set(success.url, success.content);
}
const aggregatedContent = uniqueUrls
@@ -663,7 +641,7 @@ ${aggregatedContent}
});
const errorContent = `Request failed with status ${status}
Headers: ${JSON.stringify(headers, null, 2)}
Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response truncated] ...')}`;
Response: ${rawResponseText}`;
debugLogger.error(
`[WebFetchTool] Experimental fetch failed with status ${status} for ${url}`,
);
@@ -679,11 +657,7 @@ Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response trun
lowContentType.includes('text/plain') ||
lowContentType.includes('application/json')
) {
const text = truncateString(
bodyBuffer.toString('utf8'),
MAX_CONTENT_LENGTH,
TRUNCATION_WARNING,
);
const text = bodyBuffer.toString('utf8');
return {
llmContent: text,
returnDisplay: `Fetched ${contentType} content from ${url}`,
@@ -692,16 +666,12 @@ Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response trun
if (lowContentType.includes('text/html')) {
const html = bodyBuffer.toString('utf8');
const textContent = truncateString(
convert(html, {
wordwrap: false,
selectors: [
{ selector: 'a', options: { ignoreHref: false, baseUrl: url } },
],
}),
MAX_CONTENT_LENGTH,
TRUNCATION_WARNING,
);
const textContent = convert(html, {
wordwrap: false,
selectors: [
{ selector: 'a', options: { ignoreHref: false, baseUrl: url } },
],
});
return {
llmContent: textContent,
returnDisplay: `Fetched and converted HTML content from ${url}`,
@@ -726,11 +696,7 @@ Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response trun
}
// Fallback for unknown types - try as text
const text = truncateString(
bodyBuffer.toString('utf8'),
MAX_CONTENT_LENGTH,
TRUNCATION_WARNING,
);
const text = bodyBuffer.toString('utf8');
return {
llmContent: text,
returnDisplay: `Fetched ${contentType || 'unknown'} content from ${url}`,