feat(shell): enable interactive commands with virtual terminal (#6694)

This commit is contained in:
Gal Zahavi
2025-09-11 13:27:27 -07:00
committed by GitHub
parent 8969a232ec
commit 181898cb5d
43 changed files with 2345 additions and 324 deletions
+2 -41
View File
@@ -155,8 +155,7 @@ describe('ShellTool', () => {
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
{},
);
expect(result.llmContent).toContain('Background PIDs: 54322');
expect(vi.mocked(fs.unlinkSync)).toHaveBeenCalledWith(tmpFile);
@@ -183,8 +182,7 @@ describe('ShellTool', () => {
expect.any(Function),
mockAbortSignal,
false,
undefined,
undefined,
{},
);
});
@@ -296,43 +294,6 @@ describe('ShellTool', () => {
vi.useRealTimers();
});
it('should throttle text output updates', async () => {
const invocation = shellTool.build({ command: 'stream' });
const promise = invocation.execute(mockAbortSignal, updateOutputMock);
// First chunk, should be throttled.
mockShellOutputCallback({
type: 'data',
chunk: 'hello ',
});
expect(updateOutputMock).not.toHaveBeenCalled();
// Advance time past the throttle interval.
await vi.advanceTimersByTimeAsync(OUTPUT_UPDATE_INTERVAL_MS + 1);
// Send a second chunk. THIS event triggers the update with the CUMULATIVE content.
mockShellOutputCallback({
type: 'data',
chunk: 'world',
});
// It should have been called once now with the combined output.
expect(updateOutputMock).toHaveBeenCalledOnce();
expect(updateOutputMock).toHaveBeenCalledWith('hello world');
resolveExecutionPromise({
rawOutput: Buffer.from(''),
output: '',
exitCode: 0,
signal: null,
error: null,
aborted: false,
pid: 12345,
executionMethod: 'child_process',
});
await promise;
});
it('should immediately show binary detection message and throttle progress', async () => {
const invocation = shellTool.build({ command: 'cat img' });
const promise = invocation.execute(mockAbortSignal, updateOutputMock);
+58 -57
View File
@@ -24,9 +24,13 @@ import {
} from './tools.js';
import { getErrorMessage } from '../utils/errors.js';
import { summarizeToolOutput } from '../utils/summarizer.js';
import type { ShellOutputEvent } from '../services/shellExecutionService.js';
import type {
ShellExecutionConfig,
ShellOutputEvent,
} from '../services/shellExecutionService.js';
import { ShellExecutionService } from '../services/shellExecutionService.js';
import { formatMemoryUsage } from '../utils/formatters.js';
import type { AnsiOutput } from '../utils/terminalSerializer.js';
import {
getCommandRoots,
isCommandAllowed,
@@ -41,7 +45,7 @@ export interface ShellToolParams {
directory?: string;
}
class ShellToolInvocation extends BaseToolInvocation<
export class ShellToolInvocation extends BaseToolInvocation<
ShellToolParams,
ToolResult
> {
@@ -96,9 +100,9 @@ class ShellToolInvocation extends BaseToolInvocation<
async execute(
signal: AbortSignal,
updateOutput?: (output: string) => void,
terminalColumns?: number,
terminalRows?: number,
updateOutput?: (output: string | AnsiOutput) => void,
shellExecutionConfig?: ShellExecutionConfig,
setPidCallback?: (pid: number) => void,
): Promise<ToolResult> {
const strippedCommand = stripShellWrapper(this.params.command);
@@ -131,63 +135,60 @@ class ShellToolInvocation extends BaseToolInvocation<
this.params.directory || '',
);
let cumulativeOutput = '';
let outputChunks: string[] = [cumulativeOutput];
let cumulativeOutput: string | AnsiOutput = '';
let lastUpdateTime = Date.now();
let isBinaryStream = false;
const { result: resultPromise } = await ShellExecutionService.execute(
commandToExecute,
cwd,
(event: ShellOutputEvent) => {
if (!updateOutput) {
return;
}
let currentDisplayOutput = '';
let shouldUpdate = false;
switch (event.type) {
case 'data':
if (isBinaryStream) break;
outputChunks.push(event.chunk);
if (Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS) {
cumulativeOutput = outputChunks.join('');
outputChunks = [cumulativeOutput];
currentDisplayOutput = cumulativeOutput;
shouldUpdate = true;
}
break;
case 'binary_detected':
isBinaryStream = true;
currentDisplayOutput =
'[Binary output detected. Halting stream...]';
shouldUpdate = true;
break;
case 'binary_progress':
isBinaryStream = true;
currentDisplayOutput = `[Receiving binary output... ${formatMemoryUsage(
event.bytesReceived,
)} received]`;
if (Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS) {
shouldUpdate = true;
}
break;
default: {
throw new Error('An unhandled ShellOutputEvent was found.');
const { result: resultPromise, pid } =
await ShellExecutionService.execute(
commandToExecute,
cwd,
(event: ShellOutputEvent) => {
if (!updateOutput) {
return;
}
}
if (shouldUpdate) {
updateOutput(currentDisplayOutput);
lastUpdateTime = Date.now();
}
},
signal,
this.config.getShouldUseNodePtyShell(),
terminalColumns,
terminalRows,
);
let shouldUpdate = false;
switch (event.type) {
case 'data':
if (isBinaryStream) break;
cumulativeOutput = event.chunk;
shouldUpdate = true;
break;
case 'binary_detected':
isBinaryStream = true;
cumulativeOutput =
'[Binary output detected. Halting stream...]';
shouldUpdate = true;
break;
case 'binary_progress':
isBinaryStream = true;
cumulativeOutput = `[Receiving binary output... ${formatMemoryUsage(
event.bytesReceived,
)} received]`;
if (Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS) {
shouldUpdate = true;
}
break;
default: {
throw new Error('An unhandled ShellOutputEvent was found.');
}
}
if (shouldUpdate) {
updateOutput(cumulativeOutput);
lastUpdateTime = Date.now();
}
},
signal,
this.config.getShouldUseNodePtyShell(),
shellExecutionConfig ?? {},
);
if (pid && setPidCallback) {
setPidCallback(pid);
}
const result = await resultPromise;
+10 -5
View File
@@ -7,7 +7,9 @@
import type { FunctionDeclaration, PartListUnion } from '@google/genai';
import { ToolErrorType } from './tool-error.js';
import type { DiffUpdateResult } from '../ide/ideContext.js';
import type { ShellExecutionConfig } from '../services/shellExecutionService.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import type { AnsiOutput } from '../utils/terminalSerializer.js';
/**
* Represents a validated and ready-to-execute tool call.
@@ -51,7 +53,8 @@ export interface ToolInvocation<
*/
execute(
signal: AbortSignal,
updateOutput?: (output: string) => void,
updateOutput?: (output: string | AnsiOutput) => void,
shellExecutionConfig?: ShellExecutionConfig,
): Promise<TResult>;
}
@@ -79,7 +82,8 @@ export abstract class BaseToolInvocation<
abstract execute(
signal: AbortSignal,
updateOutput?: (output: string) => void,
updateOutput?: (output: string | AnsiOutput) => void,
shellExecutionConfig?: ShellExecutionConfig,
): Promise<TResult>;
}
@@ -197,10 +201,11 @@ export abstract class DeclarativeTool<
async buildAndExecute(
params: TParams,
signal: AbortSignal,
updateOutput?: (output: string) => void,
updateOutput?: (output: string | AnsiOutput) => void,
shellExecutionConfig?: ShellExecutionConfig,
): Promise<TResult> {
const invocation = this.build(params);
return invocation.execute(signal, updateOutput);
return invocation.execute(signal, updateOutput, shellExecutionConfig);
}
/**
@@ -432,7 +437,7 @@ export function hasCycleInSchema(schema: object): boolean {
return traverse(schema, new Set<string>(), new Set<string>());
}
export type ToolResultDisplay = string | FileDiff;
export type ToolResultDisplay = string | FileDiff | AnsiOutput;
export interface FileDiff {
fileDiff: string;