diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 9f81fc9cf0..be14b3b195 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -100,6 +100,12 @@ interface BackgroundExecutionData { initialOutput?: string; } +interface BackgroundedShellInfo { + pid: number; + command: string; + initialOutput: string; +} + enum StreamProcessingStatus { Completed, UserCancelled, @@ -125,6 +131,28 @@ function isBackgroundExecutionData( ); } +function getBackgroundedShellInfo( + toolCall: TrackedCompletedToolCall | TrackedCancelledToolCall, +): BackgroundedShellInfo | undefined { + if (toolCall.request.name !== SHELL_COMMAND_NAME) { + return undefined; + } + + const response = toolCall.response as ToolResponseWithParts; + const rawData = response?.data; + const data = isBackgroundExecutionData(rawData) ? rawData : undefined; + + if (!data?.pid) { + return undefined; + } + + return { + pid: data.pid, + command: data.command ?? 'shell', + initialOutput: data.initialOutput ?? '', + }; +} + function showCitations(settings: LoadedSettings): boolean { const enabled = settings.merged.ui.showCitations; if (enabled !== undefined) { @@ -315,11 +343,11 @@ export const useGeminiStream = ( const activeToolExecutionId = useMemo(() => { const executingShellTool = toolCalls.find( - (tc) => - tc.status === 'executing' && tc.request.name === 'run_shell_command', + (tc): tc is TrackedExecutingToolCall => + tc.status === CoreToolCallStatus.Executing && + tc.request.name === SHELL_COMMAND_NAME, ); - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - return (executingShellTool as TrackedExecutingToolCall | undefined)?.pid; + return executingShellTool?.pid; }, [toolCalls]); const onExec = useCallback(async (done: Promise) => { @@ -1653,26 +1681,16 @@ export const useGeminiStream = ( !processedMemoryToolsRef.current.has(t.request.callId), ); - // Handle backgrounded shell tools - completedAndReadyToSubmitTools.forEach((t) => { - const isShell = t.request.name === 'run_shell_command'; - // Access result from the tracked tool call response - const response = t.response as ToolResponseWithParts; - const rawData = response?.data; - const data = isBackgroundExecutionData(rawData) ? rawData : undefined; - - // Use data.pid for shell commands moved to the background. - const pid = data?.pid; - - if (isShell && pid) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const command = (data?.['command'] as string) ?? 'shell'; - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const initialOutput = (data?.['initialOutput'] as string) ?? ''; - - registerBackgroundShell(pid, command, initialOutput); + for (const toolCall of completedAndReadyToSubmitTools) { + const backgroundedShell = getBackgroundedShellInfo(toolCall); + if (backgroundedShell) { + registerBackgroundShell( + backgroundedShell.pid, + backgroundedShell.command, + backgroundedShell.initialOutput, + ); } - }); + } if (newSuccessfulMemorySaves.length > 0) { // Perform the refresh only if there are new ones. diff --git a/packages/core/src/services/executionLifecycleService.test.ts b/packages/core/src/services/executionLifecycleService.test.ts index 9e2ee9007a..2702da9488 100644 --- a/packages/core/src/services/executionLifecycleService.test.ts +++ b/packages/core/src/services/executionLifecycleService.test.ts @@ -11,56 +11,6 @@ import { type ExecutionResult, } from './executionLifecycleService.js'; -const BASE_VIRTUAL_ID = 2_000_000_000; - -function resetLifecycleState() { - ( - ExecutionLifecycleService as unknown as { - activeExecutions: Map; - activeResolvers: Map; - activeListeners: Map; - exitedExecutionInfo: Map; - nextVirtualExecutionId: number; - } - ).activeExecutions.clear(); - ( - ExecutionLifecycleService as unknown as { - activeExecutions: Map; - activeResolvers: Map; - activeListeners: Map; - exitedExecutionInfo: Map; - nextVirtualExecutionId: number; - } - ).activeResolvers.clear(); - ( - ExecutionLifecycleService as unknown as { - activeExecutions: Map; - activeResolvers: Map; - activeListeners: Map; - exitedExecutionInfo: Map; - nextVirtualExecutionId: number; - } - ).activeListeners.clear(); - ( - ExecutionLifecycleService as unknown as { - activeExecutions: Map; - activeResolvers: Map; - activeListeners: Map; - exitedExecutionInfo: Map; - nextVirtualExecutionId: number; - } - ).exitedExecutionInfo.clear(); - ( - ExecutionLifecycleService as unknown as { - activeExecutions: Map; - activeResolvers: Map; - activeListeners: Map; - exitedExecutionInfo: Map; - nextVirtualExecutionId: number; - } - ).nextVirtualExecutionId = BASE_VIRTUAL_ID; -} - function createResult( overrides: Partial = {}, ): ExecutionResult { @@ -79,11 +29,11 @@ function createResult( describe('ExecutionLifecycleService', () => { beforeEach(() => { - resetLifecycleState(); + ExecutionLifecycleService.resetForTest(); }); it('completes virtual executions in the foreground and notifies exit subscribers', async () => { - const handle = ExecutionLifecycleService.createExecution(); + const handle = ExecutionLifecycleService.createVirtualExecution(); if (handle.pid === undefined) { throw new Error('Expected virtual execution ID.'); } @@ -93,7 +43,9 @@ describe('ExecutionLifecycleService', () => { ExecutionLifecycleService.appendOutput(handle.pid, 'Hello'); ExecutionLifecycleService.appendOutput(handle.pid, ' World'); - ExecutionLifecycleService.completeExecution(handle.pid, { exitCode: 0 }); + ExecutionLifecycleService.completeVirtualExecution(handle.pid, { + exitCode: 0, + }); const result = await handle.result; expect(result.output).toBe('Hello World'); @@ -108,7 +60,7 @@ describe('ExecutionLifecycleService', () => { }); it('supports backgrounding virtual executions and continues streaming updates', async () => { - const handle = ExecutionLifecycleService.createExecution(); + const handle = ExecutionLifecycleService.createVirtualExecution(); if (handle.pid === undefined) { throw new Error('Expected virtual execution ID.'); } @@ -134,7 +86,9 @@ describe('ExecutionLifecycleService', () => { expect(backgroundResult.output).toBe('Chunk 1'); ExecutionLifecycleService.appendOutput(handle.pid, '\nChunk 2'); - ExecutionLifecycleService.completeExecution(handle.pid, { exitCode: 0 }); + ExecutionLifecycleService.completeVirtualExecution(handle.pid, { + exitCode: 0, + }); await vi.waitFor(() => { expect(chunks.join('')).toContain('Chunk 2'); @@ -147,7 +101,7 @@ describe('ExecutionLifecycleService', () => { it('kills virtual executions and resolves with aborted result', async () => { const onKill = vi.fn(); - const handle = ExecutionLifecycleService.createExecution('', onKill); + const handle = ExecutionLifecycleService.createVirtualExecution('', onKill); if (handle.pid === undefined) { throw new Error('Expected virtual execution ID.'); } @@ -164,7 +118,6 @@ describe('ExecutionLifecycleService', () => { it('manages external executions through registration hooks', async () => { const writeInput = vi.fn(); - const terminate = vi.fn(); const isActive = vi.fn().mockReturnValue(true); const exitListener = vi.fn(); const chunks: string[] = []; @@ -177,7 +130,6 @@ describe('ExecutionLifecycleService', () => { getBackgroundOutput: () => output, getSubscriptionSnapshot: () => output, writeInput, - kill: terminate, isActive, }, ); @@ -203,7 +155,7 @@ describe('ExecutionLifecycleService', () => { expect(backgroundResult.output).toBe('seed +delta'); expect(backgroundResult.executionMethod).toBe('child_process'); - ExecutionLifecycleService.finalizeExecution( + ExecutionLifecycleService.completeWithResult( 4321, createResult({ pid: 4321, @@ -222,13 +174,84 @@ describe('ExecutionLifecycleService', () => { expect(lateExit).toHaveBeenCalledWith(0, undefined); unsubscribe(); + }); - const killHandle = ExecutionLifecycleService.registerExecution(4322, { + it('supports late subscription catch-up after backgrounding an external execution', async () => { + let output = 'seed'; + const onExit = vi.fn(); + const handle = ExecutionLifecycleService.registerExecution(4322, { executionMethod: 'child_process', + getBackgroundOutput: () => output, + getSubscriptionSnapshot: () => output, + }); + + ExecutionLifecycleService.onExit(4322, onExit); + ExecutionLifecycleService.background(4322); + + const backgroundResult = await handle.result; + expect(backgroundResult.backgrounded).toBe(true); + expect(backgroundResult.output).toBe('seed'); + + output += ' +late'; + ExecutionLifecycleService.emitEvent(4322, { type: 'data', chunk: ' +late' }); + + const chunks: string[] = []; + const unsubscribe = ExecutionLifecycleService.subscribe(4322, (event) => { + if (event.type === 'data' && typeof event.chunk === 'string') { + chunks.push(event.chunk); + } + }); + expect(chunks[0]).toBe('seed +late'); + + output += ' +live'; + ExecutionLifecycleService.emitEvent(4322, { type: 'data', chunk: ' +live' }); + expect(chunks[chunks.length - 1]).toBe(' +live'); + + ExecutionLifecycleService.completeWithResult( + 4322, + createResult({ + pid: 4322, + output, + rawOutput: Buffer.from(output), + executionMethod: 'child_process', + }), + ); + + await vi.waitFor(() => { + expect(onExit).toHaveBeenCalledWith(0, undefined); + }); + unsubscribe(); + }); + + it('kills external executions and settles pending promises', async () => { + const terminate = vi.fn(); + const onExit = vi.fn(); + const handle = ExecutionLifecycleService.registerExecution(4323, { + executionMethod: 'child_process', + initialOutput: 'running', kill: terminate, }); - expect(killHandle.pid).toBe(4322); - ExecutionLifecycleService.kill(4322); + ExecutionLifecycleService.onExit(4323, onExit); + ExecutionLifecycleService.kill(4323); + + const result = await handle.result; expect(terminate).toHaveBeenCalledTimes(1); + expect(result.aborted).toBe(true); + expect(result.exitCode).toBe(130); + expect(result.output).toBe('running'); + expect(result.error?.message).toContain('Operation cancelled by user'); + expect(onExit).toHaveBeenCalledWith(130, undefined); + }); + + it('rejects duplicate execution registration for active execution IDs', () => { + ExecutionLifecycleService.registerExecution(4324, { + executionMethod: 'child_process', + }); + + expect(() => { + ExecutionLifecycleService.registerExecution(4324, { + executionMethod: 'child_process', + }); + }).toThrow('Execution 4324 is already registered.'); }); }); diff --git a/packages/core/src/services/executionLifecycleService.ts b/packages/core/src/services/executionLifecycleService.ts index 8368e5c82f..636b376c0c 100644 --- a/packages/core/src/services/executionLifecycleService.ts +++ b/packages/core/src/services/executionLifecycleService.ts @@ -64,18 +64,27 @@ export interface ExternalExecutionRegistration { isActive?: () => boolean; } -interface ManagedExecutionState { +interface ManagedExecutionBase { executionMethod: ExecutionMethod; output: string; - isVirtual: boolean; - onKill?: () => void; getBackgroundOutput?: () => string; getSubscriptionSnapshot?: () => string | AnsiOutput | undefined; +} + +interface VirtualExecutionState extends ManagedExecutionBase { + kind: 'virtual'; + onKill?: () => void; +} + +interface ExternalExecutionState extends ManagedExecutionBase { + kind: 'external'; writeInput?: (input: string) => void; kill?: () => void; isActive?: () => boolean; } +type ManagedExecutionState = VirtualExecutionState | ExternalExecutionState; + /** * Central owner for execution backgrounding lifecycle across shell and tools. */ @@ -119,20 +128,55 @@ export class ExecutionLifecycleService { return executionId; } - private static createPendingResult(executionId: number): Promise { + private static createPendingResult( + executionId: number, + ): Promise { return new Promise((resolve) => { this.activeResolvers.set(executionId, resolve); }); } + private static createAbortedResult( + executionId: number, + execution: ManagedExecutionState, + ): ExecutionResult { + const output = execution.getBackgroundOutput?.() ?? execution.output; + return { + rawOutput: Buffer.from(output, 'utf8'), + output, + exitCode: 130, + signal: null, + error: new Error('Operation cancelled by user.'), + aborted: true, + pid: executionId, + executionMethod: execution.executionMethod, + }; + } + + /** + * Resets lifecycle state for isolated unit tests. + */ + static resetForTest(): void { + this.activeExecutions.clear(); + this.activeResolvers.clear(); + this.activeListeners.clear(); + this.exitedExecutionInfo.clear(); + this.nextVirtualExecutionId = 2_000_000_000; + } + static registerExecution( executionId: number, registration: ExternalExecutionRegistration, ): ExecutionHandle { + if (this.activeExecutions.has(executionId) || this.activeResolvers.has(executionId)) { + throw new Error(`Execution ${executionId} is already registered.`); + } + this.exitedExecutionInfo.delete(executionId); + this.activeExecutions.set(executionId, { executionMethod: registration.executionMethod, output: registration.initialOutput ?? '', - isVirtual: false, + kind: 'external', getBackgroundOutput: registration.getBackgroundOutput, getSubscriptionSnapshot: registration.getSubscriptionSnapshot, writeInput: registration.writeInput, @@ -146,7 +190,7 @@ export class ExecutionLifecycleService { }; } - static createExecution( + static createVirtualExecution( initialOutput = '', onKill?: () => void, ): ExecutionHandle { @@ -155,7 +199,7 @@ export class ExecutionLifecycleService { this.activeExecutions.set(executionId, { executionMethod: 'none', output: initialOutput, - isVirtual: true, + kind: 'virtual', onKill, getBackgroundOutput: () => { const state = this.activeExecutions.get(executionId); @@ -165,7 +209,6 @@ export class ExecutionLifecycleService { const state = this.activeExecutions.get(executionId); return state?.output ?? initialOutput; }, - isActive: () => true, }); return { @@ -174,6 +217,16 @@ export class ExecutionLifecycleService { }; } + /** + * @deprecated Use createVirtualExecution() for new call sites. + */ + static createExecution( + initialOutput = '', + onKill?: () => void, + ): ExecutionHandle { + return this.createVirtualExecution(initialOutput, onKill); + } + static appendOutput(executionId: number, chunk: string): void { const execution = this.activeExecutions.get(executionId); if (!execution || chunk.length === 0) { @@ -204,7 +257,31 @@ export class ExecutionLifecycleService { this.activeResolvers.delete(executionId); } - static completeExecution( + private static settleExecution( + executionId: number, + result: ExecutionResult, + ): void { + if (!this.activeExecutions.has(executionId)) { + return; + } + + this.resolvePending(executionId, result); + this.emitEvent(executionId, { + type: 'exit', + exitCode: result.exitCode, + signal: result.signal, + }); + + this.activeListeners.delete(executionId); + this.activeExecutions.delete(executionId); + this.storeExitInfo( + executionId, + result.exitCode ?? 0, + result.signal ?? undefined, + ); + } + + static completeVirtualExecution( executionId: number, options?: ExecutionCompletionOptions, ): void { @@ -222,7 +299,7 @@ export class ExecutionLifecycleService { const output = execution.getBackgroundOutput?.() ?? execution.output; - this.resolvePending(executionId, { + this.settleExecution(executionId, { rawOutput: Buffer.from(output, 'utf8'), output, exitCode, @@ -232,37 +309,33 @@ export class ExecutionLifecycleService { pid: executionId, executionMethod: execution.executionMethod, }); - - this.emitEvent(executionId, { - type: 'exit', - exitCode, - signal, - }); - - this.activeListeners.delete(executionId); - this.activeExecutions.delete(executionId); - this.storeExitInfo(executionId, exitCode ?? 0, signal ?? undefined); } + /** + * @deprecated Use completeVirtualExecution() for new call sites. + */ + static completeExecution( + executionId: number, + options?: ExecutionCompletionOptions, + ): void { + this.completeVirtualExecution(executionId, options); + } + + static completeWithResult( + executionId: number, + result: ExecutionResult, + ): void { + this.settleExecution(executionId, result); + } + + /** + * @deprecated Use completeWithResult() for new call sites. + */ static finalizeExecution( executionId: number, result: ExecutionResult, ): void { - this.resolvePending(executionId, result); - - this.emitEvent(executionId, { - type: 'exit', - exitCode: result.exitCode, - signal: result.signal, - }); - - this.activeListeners.delete(executionId); - this.activeExecutions.delete(executionId); - this.storeExitInfo( - executionId, - result.exitCode ?? 0, - result.signal ?? undefined, - ); + this.completeWithResult(executionId, result); } static background(executionId: number): void { @@ -349,20 +422,18 @@ export class ExecutionLifecycleService { return; } - if (execution.isVirtual) { + if (execution.kind === 'virtual') { execution.onKill?.(); - this.completeExecution(executionId, { - error: new Error('Operation cancelled by user.'), - aborted: true, - exitCode: 130, - }); - return; } - execution.kill?.(); - this.activeResolvers.delete(executionId); - this.activeListeners.delete(executionId); - this.activeExecutions.delete(executionId); + if (execution.kind === 'external') { + execution.kill?.(); + } + + this.completeWithResult( + executionId, + this.createAbortedResult(executionId, execution), + ); } static isActive(executionId: number): boolean { @@ -375,11 +446,11 @@ export class ExecutionLifecycleService { } } - if (execution.isVirtual) { + if (execution.kind === 'virtual') { return true; } - if (execution.isActive) { + if (execution.kind === 'external' && execution.isActive) { try { return execution.isActive(); } catch { @@ -395,6 +466,9 @@ export class ExecutionLifecycleService { } static writeInput(executionId: number, input: string): void { - this.activeExecutions.get(executionId)?.writeInput?.(input); + const execution = this.activeExecutions.get(executionId); + if (execution?.kind === 'external') { + execution.writeInput?.(input); + } } } diff --git a/packages/core/src/services/shellExecutionService.test.ts b/packages/core/src/services/shellExecutionService.test.ts index 34c95dd4c7..c2d59c1bdf 100644 --- a/packages/core/src/services/shellExecutionService.test.ts +++ b/packages/core/src/services/shellExecutionService.test.ts @@ -21,6 +21,7 @@ import { type ShellOutputEvent, type ShellExecutionConfig, } from './shellExecutionService.js'; +import { ExecutionLifecycleService } from './executionLifecycleService.js'; import type { AnsiOutput, AnsiToken } from '../utils/terminalSerializer.js'; // Hoisted Mocks @@ -166,6 +167,7 @@ describe('ShellExecutionService', () => { beforeEach(() => { vi.clearAllMocks(); + ExecutionLifecycleService.resetForTest(); mockSerializeTerminalToObject.mockReturnValue([]); mockIsBinary.mockReturnValue(false); mockPlatform.mockReturnValue('linux'); diff --git a/packages/core/src/services/shellExecutionService.ts b/packages/core/src/services/shellExecutionService.ts index 4fef30cdd0..336db6fb41 100644 --- a/packages/core/src/services/shellExecutionService.ts +++ b/packages/core/src/services/shellExecutionService.ts @@ -28,7 +28,13 @@ import { type EnvironmentSanitizationConfig, } from './environmentSanitization.js'; import { killProcessGroup } from '../utils/process-utils.js'; -import { ExecutionLifecycleService } from './executionLifecycleService.js'; +import { + ExecutionLifecycleService, + type ExecutionCompletionOptions, + type ExecutionHandle, + type ExecutionOutputEvent, + type ExecutionResult, +} from './executionLifecycleService.js'; const { Terminal } = pkg; const MAX_CHILD_PROCESS_BUFFER_SIZE = 16 * 1024 * 1024; // 16MB @@ -67,34 +73,10 @@ function ensurePromptvarsDisabled(command: string, shell: ShellType): string { } /** A structured result from a shell command execution. */ -export interface ShellExecutionResult { - /** The raw, unprocessed output buffer. */ - rawOutput: Buffer; - /** The combined, decoded output as a string. */ - output: string; - /** The process exit code, or null if terminated by a signal. */ - exitCode: number | null; - /** The signal that terminated the process, if any. */ - signal: number | null; - /** An error object if the process failed to spawn. */ - error: Error | null; - /** A boolean indicating if the command was aborted by the user. */ - aborted: boolean; - /** The process ID of the spawned shell. */ - pid: number | undefined; - /** The method used to execute the shell command. */ - executionMethod: 'lydell-node-pty' | 'node-pty' | 'child_process' | 'none'; - /** Whether the command was moved to the background. */ - backgrounded?: boolean; -} +export type ShellExecutionResult = ExecutionResult; /** A handle for an ongoing shell execution. */ -export interface ShellExecutionHandle { - /** The process ID of the spawned shell. */ - pid: number | undefined; - /** A promise that resolves with the complete execution result. */ - result: Promise; -} +export type ShellExecutionHandle = ExecutionHandle; export interface ShellExecutionConfig { terminalWidth?: number; @@ -113,31 +95,7 @@ export interface ShellExecutionConfig { /** * Describes a structured event emitted during shell command execution. */ -export type ShellOutputEvent = - | { - /** The event contains a chunk of output data. */ - type: 'data'; - /** The decoded string chunk. */ - chunk: string | AnsiOutput; - } - | { - /** Signals that the output stream has been identified as binary. */ - type: 'binary_detected'; - } - | { - /** Provides progress updates for a binary stream. */ - type: 'binary_progress'; - /** The total number of bytes received so far. */ - bytesReceived: number; - } - | { - /** Signals that the process has exited. */ - type: 'exit'; - /** The exit code of the process, if any. */ - exitCode: number | null; - /** The signal that terminated the process, if any. */ - signal: number | null; - }; +export type ShellOutputEvent = ExecutionOutputEvent; interface ActivePty { ptyProcess: IPty; @@ -269,7 +227,7 @@ export class ShellExecutionService { initialOutput = '', onKill?: () => void, ): ShellExecutionHandle { - return ExecutionLifecycleService.createExecution(initialOutput, onKill); + return ExecutionLifecycleService.createVirtualExecution(initialOutput, onKill); } static appendVirtualOutput(pid: number, chunk: string): void { @@ -278,14 +236,9 @@ export class ShellExecutionService { static completeVirtualExecution( pid: number, - options?: { - exitCode?: number | null; - signal?: number | null; - error?: Error | null; - aborted?: boolean; - }, + options?: ExecutionCompletionOptions, ): void { - ExecutionLifecycleService.completeExecution(pid, options); + ExecutionLifecycleService.completeVirtualExecution(pid, options); } private static childProcessFallback( @@ -469,7 +422,7 @@ export class ShellExecutionService { signal: exitSignal, }; onOutputEvent(event); - ExecutionLifecycleService.finalizeExecution(child.pid, resultPayload); + ExecutionLifecycleService.completeWithResult(child.pid, resultPayload); } else { resolveWithoutPid?.(resultPayload); } @@ -862,7 +815,7 @@ export class ShellExecutionService { }; onOutputEvent(event); - ExecutionLifecycleService.finalizeExecution(ptyPid, { + ExecutionLifecycleService.completeWithResult(ptyPid, { rawOutput: Buffer.concat(outputChunks), output: getFullBufferText(headlessTerminal), exitCode,