feat(core): improve shell execution service reliability (#10607)

This commit is contained in:
Gal Zahavi
2025-10-10 15:14:37 -07:00
committed by GitHub
parent 37678acb1a
commit 265d39f337
4 changed files with 292 additions and 125 deletions

View File

@@ -219,7 +219,7 @@ describe('useShellCommandProcessor', () => {
vi.useRealTimers();
});
it('should throttle pending UI updates for text streams (non-interactive)', async () => {
it('should update UI for text streams (non-interactive)', async () => {
const { result } = renderProcessorHook();
act(() => {
result.current.handleShellCommand(
@@ -243,61 +243,43 @@ describe('useShellCommandProcessor', () => {
);
// Wait for the async PID update to happen.
// Call 1: Initial, Call 2: PID update
await vi.waitFor(() => {
// It's called once for initial, and once for the PID update.
expect(setPendingHistoryItemMock).toHaveBeenCalledTimes(2);
});
// Simulate rapid output
// Get the state after the PID update to feed into the stream updaters
const pidUpdateFn = setPendingHistoryItemMock.mock.calls[1][0];
const initialState = setPendingHistoryItemMock.mock.calls[0][0];
const stateAfterPid = pidUpdateFn(initialState);
// Simulate first output chunk
act(() => {
mockShellOutputCallback({
type: 'data',
chunk: 'hello',
});
});
// The count should still be 2, as throttling is in effect.
expect(setPendingHistoryItemMock).toHaveBeenCalledTimes(2);
// A UI update should have occurred.
expect(setPendingHistoryItemMock).toHaveBeenCalledTimes(3);
// Simulate more rapid output
const streamUpdateFn1 = setPendingHistoryItemMock.mock.calls[2][0];
const stateAfterStream1 = streamUpdateFn1(stateAfterPid);
expect(stateAfterStream1.tools[0].resultDisplay).toBe('hello');
// Simulate second output chunk
act(() => {
mockShellOutputCallback({
type: 'data',
chunk: ' world',
});
});
expect(setPendingHistoryItemMock).toHaveBeenCalledTimes(2);
// Another UI update should have occurred.
expect(setPendingHistoryItemMock).toHaveBeenCalledTimes(4);
// Advance time, but the update won't happen until the next event
await act(async () => {
await vi.advanceTimersByTimeAsync(OUTPUT_UPDATE_INTERVAL_MS + 1);
});
// Trigger one more event to cause the throttled update to fire.
act(() => {
mockShellOutputCallback({
type: 'data',
chunk: '',
});
});
// Now the cumulative update should have occurred.
// Call 1: Initial, Call 2: PID update, Call 3: Throttled stream update
expect(setPendingHistoryItemMock).toHaveBeenCalledTimes(3);
const streamUpdateFn = setPendingHistoryItemMock.mock.calls[2][0];
if (!streamUpdateFn || typeof streamUpdateFn !== 'function') {
throw new Error(
'setPendingHistoryItem was not called with a stream updater function',
);
}
// Get the state after the PID update to feed into the stream updater
const pidUpdateFn = setPendingHistoryItemMock.mock.calls[1][0];
const initialState = setPendingHistoryItemMock.mock.calls[0][0];
const stateAfterPid = pidUpdateFn(initialState);
const stateAfterStream = streamUpdateFn(stateAfterPid);
expect(stateAfterStream.tools[0].resultDisplay).toBe('hello world');
const streamUpdateFn2 = setPendingHistoryItemMock.mock.calls[3][0];
const stateAfterStream2 = streamUpdateFn2(stateAfterStream1);
expect(stateAfterStream2.tools[0].resultDisplay).toBe('hello world');
});
it('should show binary progress messages correctly', async () => {

View File

@@ -109,7 +109,6 @@ export const useShellCommandProcessor = (
const executeCommand = async (
resolve: (value: void | PromiseLike<void>) => void,
) => {
let lastUpdateTime = Date.now();
let cumulativeStdout: string | AnsiOutput = '';
let isBinaryStream = false;
let binaryBytesReceived = 0;
@@ -168,6 +167,7 @@ export const useShellCommandProcessor = (
typeof cumulativeStdout === 'string'
) {
cumulativeStdout += event.chunk;
shouldUpdate = true;
}
break;
case 'binary_detected':
@@ -178,6 +178,7 @@ export const useShellCommandProcessor = (
case 'binary_progress':
isBinaryStream = true;
binaryBytesReceived = event.bytesReceived;
shouldUpdate = true;
break;
default: {
throw new Error('An unhandled ShellOutputEvent was found.');
@@ -200,10 +201,7 @@ export const useShellCommandProcessor = (
}
// Throttle pending UI updates, but allow forced updates.
if (
shouldUpdate ||
Date.now() - lastUpdateTime > OUTPUT_UPDATE_INTERVAL_MS
) {
if (shouldUpdate) {
setPendingHistoryItem((prevItem) => {
if (prevItem?.type === 'tool_group') {
return {
@@ -217,7 +215,6 @@ export const useShellCommandProcessor = (
}
return prevItem;
});
lastUpdateTime = Date.now();
}
},
abortSignal,

View File

@@ -153,7 +153,7 @@ describe('ShellExecutionService', () => {
simulation: (
ptyProcess: typeof mockPtyProcess,
ac: AbortController,
) => void,
) => void | Promise<void>,
config = shellExecutionConfig,
) => {
const abortController = new AbortController();
@@ -167,7 +167,7 @@ describe('ShellExecutionService', () => {
);
await new Promise((resolve) => process.nextTick(resolve));
simulation(mockPtyProcess, abortController);
await simulation(mockPtyProcess, abortController);
const result = await handle.result;
return { result, handle, abortController };
};
@@ -356,6 +356,63 @@ describe('ShellExecutionService', () => {
expect(result.aborted).toBe(true);
// The process kill is mocked, so we just check that the flag is set.
});
it('should send SIGTERM and then SIGKILL on abort', async () => {
const sigkillPromise = new Promise<void>((resolve) => {
mockProcessKill.mockImplementation((pid, signal) => {
if (signal === 'SIGKILL' && pid === -mockPtyProcess.pid) {
resolve();
}
return true;
});
});
const { result } = await simulateExecution(
'long-running-process',
async (pty, abortController) => {
abortController.abort();
await sigkillPromise; // Wait for SIGKILL to be sent before exiting.
pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: 9 });
},
);
expect(result.aborted).toBe(true);
// Verify the calls were made in the correct order.
const killCalls = mockProcessKill.mock.calls;
const sigtermCallIndex = killCalls.findIndex(
(call) => call[0] === -mockPtyProcess.pid && call[1] === 'SIGTERM',
);
const sigkillCallIndex = killCalls.findIndex(
(call) => call[0] === -mockPtyProcess.pid && call[1] === 'SIGKILL',
);
expect(sigtermCallIndex).toBe(0);
expect(sigkillCallIndex).toBe(1);
expect(sigtermCallIndex).toBeLessThan(sigkillCallIndex);
expect(result.signal).toBe(9);
});
it('should resolve without waiting for the processing chain on abort', async () => {
const { result } = await simulateExecution(
'long-output',
(pty, abortController) => {
// Simulate a lot of data being in the queue to be processed
for (let i = 0; i < 1000; i++) {
pty.onData.mock.calls[0][0]('some data');
}
abortController.abort();
pty.onExit.mock.calls[0][0]({ exitCode: 1, signal: null });
},
);
// The main assertion here is implicit: the `await` for the result above
// should complete without timing out. This proves that the resolution
// was not blocked by the long chain of data processing promises,
// which is the desired behavior on abort.
expect(result.aborted).toBe(true);
});
});
describe('Binary Output', () => {
@@ -633,6 +690,36 @@ describe('ShellExecutionService child_process fallback', () => {
expect(result.output.trim()).toBe('');
expect(onOutputEventMock).not.toHaveBeenCalled();
});
it('should truncate stdout using a sliding window and show a warning', async () => {
const MAX_SIZE = 16 * 1024 * 1024;
const chunk1 = 'a'.repeat(MAX_SIZE / 2 - 5);
const chunk2 = 'b'.repeat(MAX_SIZE / 2 - 5);
const chunk3 = 'c'.repeat(20);
const { result } = await simulateExecution('large-output', (cp) => {
cp.stdout?.emit('data', Buffer.from(chunk1));
cp.stdout?.emit('data', Buffer.from(chunk2));
cp.stdout?.emit('data', Buffer.from(chunk3));
cp.emit('exit', 0, null);
});
const truncationMessage =
'[GEMINI_CLI_WARNING: Output truncated. The buffer is limited to 16MB.]';
expect(result.output).toContain(truncationMessage);
const outputWithoutMessage = result.output
.substring(0, result.output.indexOf(truncationMessage))
.trimEnd();
expect(outputWithoutMessage.length).toBe(MAX_SIZE);
const expectedStart = (chunk1 + chunk2 + chunk3).slice(-MAX_SIZE);
expect(
outputWithoutMessage.startsWith(expectedStart.substring(0, 10)),
).toBe(true);
expect(outputWithoutMessage.endsWith('c'.repeat(20))).toBe(true);
}, 20000);
});
describe('Failed Execution', () => {

View File

@@ -21,6 +21,7 @@ import {
const { Terminal } = pkg;
const SIGKILL_TIMEOUT_MS = 200;
const MAX_CHILD_PROCESS_BUFFER_SIZE = 16 * 1024 * 1024; // 16MB
/** A structured result from a shell command execution. */
export interface ShellExecutionResult {
@@ -150,6 +151,36 @@ export class ShellExecutionService {
);
}
private static appendAndTruncate(
currentBuffer: string,
chunk: string,
maxSize: number,
): { newBuffer: string; truncated: boolean } {
const chunkLength = chunk.length;
const currentLength = currentBuffer.length;
const newTotalLength = currentLength + chunkLength;
if (newTotalLength <= maxSize) {
return { newBuffer: currentBuffer + chunk, truncated: false };
}
// Truncation is needed.
if (chunkLength >= maxSize) {
// The new chunk is larger than or equal to the max buffer size.
// The new buffer will be the tail of the new chunk.
return {
newBuffer: chunk.substring(chunkLength - maxSize),
truncated: true,
};
}
// The combined buffer exceeds the max size, but the new chunk is smaller than it.
// We need to truncate the current buffer from the beginning to make space.
const charsToTrim = newTotalLength - maxSize;
const truncatedBuffer = currentBuffer.substring(charsToTrim);
return { newBuffer: truncatedBuffer + chunk, truncated: true };
}
private static childProcessFallback(
commandToExecute: string,
cwd: string,
@@ -179,6 +210,8 @@ export class ShellExecutionService {
let stdout = '';
let stderr = '';
let stdoutTruncated = false;
let stderrTruncated = false;
const outputChunks: Buffer[] = [];
let error: Error | null = null;
let exited = false;
@@ -215,9 +248,25 @@ export class ShellExecutionService {
const decodedChunk = decoder.decode(data, { stream: true });
if (stream === 'stdout') {
stdout += decodedChunk;
const { newBuffer, truncated } = this.appendAndTruncate(
stdout,
decodedChunk,
MAX_CHILD_PROCESS_BUFFER_SIZE,
);
stdout = newBuffer;
if (truncated) {
stdoutTruncated = true;
}
} else {
stderr += decodedChunk;
const { newBuffer, truncated } = this.appendAndTruncate(
stderr,
decodedChunk,
MAX_CHILD_PROCESS_BUFFER_SIZE,
);
stderr = newBuffer;
if (truncated) {
stderrTruncated = true;
}
}
}
};
@@ -229,9 +278,16 @@ export class ShellExecutionService {
const { finalBuffer } = cleanup();
// Ensure we don't add an extra newline if stdout already ends with one.
const separator = stdout.endsWith('\n') ? '' : '\n';
const combinedOutput =
let combinedOutput =
stdout + (stderr ? (stdout ? separator : '') + stderr : '');
if (stdoutTruncated || stderrTruncated) {
const truncationMessage = `\n[GEMINI_CLI_WARNING: Output truncated. The buffer is limited to ${
MAX_CHILD_PROCESS_BUFFER_SIZE / (1024 * 1024)
}MB.]`;
combinedOutput += truncationMessage;
}
const finalStrippedOutput = stripAnsi(combinedOutput).trim();
if (isStreamingRawContent) {
@@ -388,86 +444,99 @@ export class ShellExecutionService {
let hasStartedOutput = false;
let renderTimeout: NodeJS.Timeout | null = null;
const render = (finalRender = false) => {
if (renderTimeout) {
clearTimeout(renderTimeout);
const renderFn = () => {
renderTimeout = null;
if (!isStreamingRawContent) {
return;
}
const renderFn = () => {
if (!isStreamingRawContent) {
return;
}
if (!shellExecutionConfig.disableDynamicLineTrimming) {
if (!hasStartedOutput) {
const bufferText = getFullBufferText(headlessTerminal);
if (bufferText.trim().length === 0) {
return;
}
hasStartedOutput = true;
if (!shellExecutionConfig.disableDynamicLineTrimming) {
if (!hasStartedOutput) {
const bufferText = getFullBufferText(headlessTerminal);
if (bufferText.trim().length === 0) {
return;
}
hasStartedOutput = true;
}
}
let newOutput: AnsiOutput;
if (shellExecutionConfig.showColor) {
newOutput = serializeTerminalToObject(headlessTerminal);
} else {
const buffer = headlessTerminal.buffer.active;
const lines: AnsiOutput = [];
for (let y = 0; y < headlessTerminal.rows; y++) {
const line = buffer.getLine(buffer.viewportY + y);
const lineContent = line ? line.translateToString(true) : '';
lines.push([
{
text: lineContent,
bold: false,
italic: false,
underline: false,
dim: false,
inverse: false,
fg: '',
bg: '',
},
]);
}
newOutput = lines;
}
let lastNonEmptyLine = -1;
for (let i = newOutput.length - 1; i >= 0; i--) {
const line = newOutput[i];
if (
line
.map((segment) => segment.text)
.join('')
.trim().length > 0
) {
lastNonEmptyLine = i;
break;
}
}
const trimmedOutput = newOutput.slice(0, lastNonEmptyLine + 1);
const finalOutput = shellExecutionConfig.disableDynamicLineTrimming
? newOutput
: trimmedOutput;
// Using stringify for a quick deep comparison.
if (JSON.stringify(output) !== JSON.stringify(finalOutput)) {
output = finalOutput;
onOutputEvent({
type: 'data',
chunk: finalOutput,
});
}
};
if (finalRender) {
renderFn();
const buffer = headlessTerminal.buffer.active;
let newOutput: AnsiOutput;
if (shellExecutionConfig.showColor) {
newOutput = serializeTerminalToObject(headlessTerminal);
} else {
renderTimeout = setTimeout(renderFn, 17);
const lines: AnsiOutput = [];
for (let y = 0; y < headlessTerminal.rows; y++) {
const line = buffer.getLine(buffer.viewportY + y);
const lineContent = line ? line.translateToString(true) : '';
lines.push([
{
text: lineContent,
bold: false,
italic: false,
underline: false,
dim: false,
inverse: false,
fg: '',
bg: '',
},
]);
}
newOutput = lines;
}
let lastNonEmptyLine = -1;
for (let i = newOutput.length - 1; i >= 0; i--) {
const line = newOutput[i];
if (
line
.map((segment) => segment.text)
.join('')
.trim().length > 0
) {
lastNonEmptyLine = i;
break;
}
}
if (buffer.cursorY > lastNonEmptyLine) {
lastNonEmptyLine = buffer.cursorY;
}
const trimmedOutput = newOutput.slice(0, lastNonEmptyLine + 1);
const finalOutput = shellExecutionConfig.disableDynamicLineTrimming
? newOutput
: trimmedOutput;
// Using stringify for a quick deep comparison.
if (JSON.stringify(output) !== JSON.stringify(finalOutput)) {
output = finalOutput;
onOutputEvent({
type: 'data',
chunk: finalOutput,
});
}
};
const render = (finalRender = false) => {
if (finalRender) {
if (renderTimeout) {
clearTimeout(renderTimeout);
}
renderFn();
return;
}
if (renderTimeout) {
return;
}
renderTimeout = setTimeout(() => {
renderFn();
renderTimeout = null;
}, 68);
};
headlessTerminal.onScroll(() => {
@@ -503,6 +572,10 @@ export class ShellExecutionService {
if (isStreamingRawContent) {
const decodedChunk = decoder.decode(data, { stream: true });
if (decodedChunk.length === 0) {
resolve();
return;
}
isWriting = true;
headlessTerminal.write(decodedChunk, () => {
render();
@@ -535,7 +608,7 @@ export class ShellExecutionService {
abortSignal.removeEventListener('abort', abortHandler);
this.activePtys.delete(ptyProcess.pid);
processingChain.then(() => {
const finalize = () => {
render(true);
const finalBuffer = Buffer.concat(outputChunks);
@@ -551,6 +624,26 @@ export class ShellExecutionService {
(ptyInfo?.name as 'node-pty' | 'lydell-node-pty') ??
'node-pty',
});
};
if (abortSignal.aborted) {
finalize();
return;
}
const processingComplete = processingChain.then(() => 'processed');
const abortFired = new Promise<'aborted'>((res) => {
if (abortSignal.aborted) {
res('aborted');
return;
}
abortSignal.addEventListener('abort', () => res('aborted'), {
once: true,
});
});
Promise.race([processingComplete, abortFired]).then(() => {
finalize();
});
},
);
@@ -562,10 +655,18 @@ export class ShellExecutionService {
} else {
try {
// Kill the entire process group
process.kill(-ptyProcess.pid, 'SIGINT');
process.kill(-ptyProcess.pid, 'SIGTERM');
await new Promise((res) => setTimeout(res, SIGKILL_TIMEOUT_MS));
if (!exited) {
process.kill(-ptyProcess.pid, 'SIGKILL');
}
} catch (_e) {
// Fallback to killing just the process if the group kill fails
ptyProcess.kill('SIGINT');
ptyProcess.kill('SIGTERM');
await new Promise((res) => setTimeout(res, SIGKILL_TIMEOUT_MS));
if (!exited) {
ptyProcess.kill('SIGKILL');
}
}
}
}