mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-24 04:52:43 -07:00
Enable Ctrl+B backgrounding for remote agent calls
This commit is contained in:
@@ -599,6 +599,35 @@ describe('useGeminiStream', () => {
|
||||
expect(mockSendMessageStream).not.toHaveBeenCalled(); // submitQuery uses this
|
||||
});
|
||||
|
||||
it('should expose activePtyId for non-shell executing tools that report pid', () => {
|
||||
const remoteExecutingTool: TrackedExecutingToolCall = {
|
||||
request: {
|
||||
callId: 'remote-call-1',
|
||||
name: 'remote_agent_call',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-remote',
|
||||
},
|
||||
status: CoreToolCallStatus.Executing,
|
||||
responseSubmittedToGemini: false,
|
||||
tool: {
|
||||
name: 'remote_agent_call',
|
||||
displayName: 'Remote Agent',
|
||||
description: 'Remote agent execution',
|
||||
build: vi.fn(),
|
||||
} as any,
|
||||
invocation: {
|
||||
getDescription: () => 'Calling remote agent',
|
||||
} as unknown as AnyToolInvocation,
|
||||
startTime: Date.now(),
|
||||
liveOutput: 'working...',
|
||||
pid: 4242,
|
||||
};
|
||||
|
||||
const { result } = renderTestHook([remoteExecutingTool]);
|
||||
expect(result.current.activePtyId).toBe(4242);
|
||||
});
|
||||
|
||||
it('should submit tool responses when all tool calls are completed and ready', async () => {
|
||||
const toolCall1ResponseParts: Part[] = [{ text: 'tool 1 final response' }];
|
||||
const toolCall2ResponseParts: Part[] = [{ text: 'tool 2 final response' }];
|
||||
|
||||
@@ -94,7 +94,7 @@ type ToolResponseWithParts = ToolCallResponseInfo & {
|
||||
llmContent?: PartListUnion;
|
||||
};
|
||||
|
||||
interface ShellToolData {
|
||||
interface BackgroundToolData {
|
||||
pid?: number;
|
||||
command?: string;
|
||||
initialOutput?: string;
|
||||
@@ -111,11 +111,11 @@ const SUPPRESSED_TOOL_ERRORS_NOTE =
|
||||
const LOW_VERBOSITY_FAILURE_NOTE =
|
||||
'This request failed. Press F12 for diagnostics, or run /settings and change "Error Verbosity" to full for full details.';
|
||||
|
||||
function isShellToolData(data: unknown): data is ShellToolData {
|
||||
function isBackgroundToolData(data: unknown): data is BackgroundToolData {
|
||||
if (typeof data !== 'object' || data === null) {
|
||||
return false;
|
||||
}
|
||||
const d = data as Partial<ShellToolData>;
|
||||
const d = data as Partial<BackgroundToolData>;
|
||||
return (
|
||||
(d.pid === undefined || typeof d.pid === 'number') &&
|
||||
(d.command === undefined || typeof d.command === 'string') &&
|
||||
@@ -312,12 +312,12 @@ export const useGeminiStream = (
|
||||
);
|
||||
|
||||
const activeToolPtyId = useMemo(() => {
|
||||
const executingShellTool = toolCalls.find(
|
||||
(tc) =>
|
||||
tc.status === 'executing' && tc.request.name === 'run_shell_command',
|
||||
const executingBackgroundableTool = toolCalls.find(
|
||||
(tc): tc is TrackedExecutingToolCall =>
|
||||
tc.status === CoreToolCallStatus.Executing &&
|
||||
typeof tc.pid === 'number',
|
||||
);
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
return (executingShellTool as TrackedExecutingToolCall | undefined)?.pid;
|
||||
return executingBackgroundableTool?.pid;
|
||||
}, [toolCalls]);
|
||||
|
||||
const onExec = useCallback(async (done: Promise<void>) => {
|
||||
@@ -1651,22 +1651,17 @@ export const useGeminiStream = (
|
||||
!processedMemoryToolsRef.current.has(t.request.callId),
|
||||
);
|
||||
|
||||
// Handle backgrounded shell tools
|
||||
// Handle tools moved to the background (shell + remote agents).
|
||||
completedAndReadyToSubmitTools.forEach((t) => {
|
||||
const isShell = t.request.name === 'run_shell_command';
|
||||
// Access result from the tracked tool call response
|
||||
const response = t.response as ToolResponseWithParts;
|
||||
const rawData = response?.data;
|
||||
const data = isShellToolData(rawData) ? rawData : undefined;
|
||||
|
||||
// Use data.pid for shell commands moved to the background.
|
||||
const data = isBackgroundToolData(rawData) ? rawData : undefined;
|
||||
const pid = data?.pid;
|
||||
|
||||
if (isShell && pid) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const command = (data?.['command'] as string) ?? 'shell';
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const initialOutput = (data?.['initialOutput'] as string) ?? '';
|
||||
if (pid) {
|
||||
const command = data.command ?? t.request.name;
|
||||
const initialOutput = data.initialOutput ?? '';
|
||||
|
||||
registerBackgroundShell(pid, command, initialOutput);
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import type { RemoteAgentDefinition } from './types.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
||||
import type { A2AAuthProvider } from './auth-provider/types.js';
|
||||
import { ShellExecutionService } from '../services/shellExecutionService.js';
|
||||
|
||||
// Mock A2AClientManager
|
||||
vi.mock('./a2a-client-manager.js', () => ({
|
||||
@@ -583,6 +584,88 @@ describe('RemoteAgentInvocation', () => {
|
||||
'Generating...\n\nArtifact (Result):\nPart 1 Part 2',
|
||||
);
|
||||
});
|
||||
|
||||
it('should support Ctrl+B backgrounding through ShellExecutionService', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
|
||||
let releaseSecondChunk: (() => void) | undefined;
|
||||
const secondChunkGate = new Promise<void>((resolve) => {
|
||||
releaseSecondChunk = resolve;
|
||||
});
|
||||
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Chunk 1' }],
|
||||
};
|
||||
await secondChunkGate;
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-2',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Chunk 2' }],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
let pid: number | undefined;
|
||||
const onExit = vi.fn();
|
||||
let unsubscribeOnExit: (() => void) | undefined;
|
||||
const streamedOutputChunks: string[] = [];
|
||||
let unsubscribeStream: (() => void) | undefined;
|
||||
|
||||
const updateOutput = vi.fn((output: unknown) => {
|
||||
if (output === 'Chunk 1' && pid) {
|
||||
ShellExecutionService.background(pid);
|
||||
unsubscribeStream = ShellExecutionService.subscribe(pid, (event) => {
|
||||
if (event.type === 'data' && typeof event.chunk === 'string') {
|
||||
streamedOutputChunks.push(event.chunk);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{ query: 'long task' },
|
||||
mockMessageBus,
|
||||
);
|
||||
|
||||
const resultPromise = invocation.execute(
|
||||
new AbortController().signal,
|
||||
updateOutput,
|
||||
undefined,
|
||||
(newPid) => {
|
||||
pid = newPid;
|
||||
unsubscribeOnExit = ShellExecutionService.onExit(newPid, onExit);
|
||||
},
|
||||
);
|
||||
|
||||
const result = await resultPromise;
|
||||
expect(pid).toBeDefined();
|
||||
expect(result.returnDisplay).toContain(
|
||||
'Remote agent moved to background',
|
||||
);
|
||||
expect(result.data).toMatchObject({
|
||||
pid,
|
||||
initialOutput: 'Chunk 1',
|
||||
});
|
||||
|
||||
releaseSecondChunk?.();
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(onExit).toHaveBeenCalledWith(0, undefined);
|
||||
});
|
||||
await vi.waitFor(() => {
|
||||
expect(streamedOutputChunks.join('')).toContain('Chunk 2');
|
||||
});
|
||||
|
||||
unsubscribeStream?.();
|
||||
unsubscribeOnExit?.();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Confirmations', () => {
|
||||
|
||||
@@ -27,6 +27,7 @@ import type { AuthenticationHandler } from '@a2a-js/sdk/client';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { AnsiOutput } from '../utils/terminalSerializer.js';
|
||||
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
||||
import { ShellExecutionService } from '../services/shellExecutionService.js';
|
||||
|
||||
/**
|
||||
* Authentication handler implementation using Google Application Default Credentials (ADC).
|
||||
@@ -145,102 +146,192 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
||||
};
|
||||
}
|
||||
|
||||
private publishBackgroundDelta(
|
||||
pid: number,
|
||||
previousOutput: string,
|
||||
nextOutput: string,
|
||||
): string {
|
||||
if (nextOutput.length === 0 || nextOutput === previousOutput) {
|
||||
return previousOutput;
|
||||
}
|
||||
|
||||
if (nextOutput.startsWith(previousOutput)) {
|
||||
ShellExecutionService.appendVirtualOutput(
|
||||
pid,
|
||||
nextOutput.slice(previousOutput.length),
|
||||
);
|
||||
return nextOutput;
|
||||
}
|
||||
|
||||
// If the reassembled output changes non-monotonically, resync by appending
|
||||
// the full latest snapshot with a clear separator.
|
||||
ShellExecutionService.appendVirtualOutput(
|
||||
pid,
|
||||
`\n\n[Output updated]\n${nextOutput}`,
|
||||
);
|
||||
return nextOutput;
|
||||
}
|
||||
|
||||
async execute(
|
||||
_signal: AbortSignal,
|
||||
updateOutput?: (output: string | AnsiOutput) => void,
|
||||
_shellExecutionConfig?: unknown,
|
||||
setPidCallback?: (pid: number) => void,
|
||||
): Promise<ToolResult> {
|
||||
// 1. Ensure the agent is loaded (cached by manager)
|
||||
// We assume the user has provided an access token via some mechanism (TODO),
|
||||
// or we rely on ADC.
|
||||
const reassembler = new A2AResultReassembler();
|
||||
try {
|
||||
const priorState = RemoteAgentInvocation.sessionState.get(
|
||||
this.definition.name,
|
||||
);
|
||||
if (priorState) {
|
||||
this.contextId = priorState.contextId;
|
||||
this.taskId = priorState.taskId;
|
||||
}
|
||||
const executionController = new AbortController();
|
||||
const onAbort = () => executionController.abort();
|
||||
_signal.addEventListener('abort', onAbort, { once: true });
|
||||
|
||||
const authHandler = await this.getAuthHandler();
|
||||
const { pid, result } = ShellExecutionService.createVirtualExecution(
|
||||
'',
|
||||
() => executionController.abort(),
|
||||
);
|
||||
if (pid === undefined) {
|
||||
_signal.removeEventListener('abort', onAbort);
|
||||
return {
|
||||
llmContent: [
|
||||
{ text: 'Error calling remote agent: missing execution pid.' },
|
||||
],
|
||||
returnDisplay: 'Error calling remote agent: missing execution pid.',
|
||||
error: {
|
||||
message: 'Error calling remote agent: missing execution pid.',
|
||||
},
|
||||
};
|
||||
}
|
||||
const backgroundPid = pid;
|
||||
setPidCallback?.(backgroundPid);
|
||||
|
||||
if (!this.clientManager.getClient(this.definition.name)) {
|
||||
await this.clientManager.loadAgent(
|
||||
const run = async () => {
|
||||
let lastOutput = '';
|
||||
try {
|
||||
const priorState = RemoteAgentInvocation.sessionState.get(
|
||||
this.definition.name,
|
||||
this.definition.agentCardUrl,
|
||||
authHandler,
|
||||
);
|
||||
}
|
||||
if (priorState) {
|
||||
this.contextId = priorState.contextId;
|
||||
this.taskId = priorState.taskId;
|
||||
}
|
||||
|
||||
const message = this.params.query;
|
||||
const authHandler = await this.getAuthHandler();
|
||||
|
||||
const stream = this.clientManager.sendMessageStream(
|
||||
this.definition.name,
|
||||
message,
|
||||
{
|
||||
if (!this.clientManager.getClient(this.definition.name)) {
|
||||
await this.clientManager.loadAgent(
|
||||
this.definition.name,
|
||||
this.definition.agentCardUrl,
|
||||
authHandler,
|
||||
);
|
||||
}
|
||||
|
||||
const stream = this.clientManager.sendMessageStream(
|
||||
this.definition.name,
|
||||
this.params.query,
|
||||
{
|
||||
contextId: this.contextId,
|
||||
taskId: this.taskId,
|
||||
signal: executionController.signal,
|
||||
},
|
||||
);
|
||||
|
||||
let finalResponse: SendMessageResult | undefined;
|
||||
|
||||
for await (const chunk of stream) {
|
||||
if (executionController.signal.aborted) {
|
||||
throw new Error('Operation aborted');
|
||||
}
|
||||
finalResponse = chunk;
|
||||
reassembler.update(chunk);
|
||||
|
||||
const currentOutput = reassembler.toString();
|
||||
lastOutput = this.publishBackgroundDelta(
|
||||
backgroundPid,
|
||||
lastOutput,
|
||||
currentOutput,
|
||||
);
|
||||
if (updateOutput) {
|
||||
updateOutput(currentOutput);
|
||||
}
|
||||
|
||||
const {
|
||||
contextId: newContextId,
|
||||
taskId: newTaskId,
|
||||
clearTaskId,
|
||||
} = extractIdsFromResponse(chunk);
|
||||
|
||||
if (newContextId) {
|
||||
this.contextId = newContextId;
|
||||
}
|
||||
|
||||
this.taskId = clearTaskId ? undefined : (newTaskId ?? this.taskId);
|
||||
}
|
||||
|
||||
if (!finalResponse) {
|
||||
throw new Error('No response from remote agent.');
|
||||
}
|
||||
|
||||
debugLogger.debug(
|
||||
`[RemoteAgent] Final response from ${this.definition.name}:\n${JSON.stringify(finalResponse, null, 2)}`,
|
||||
);
|
||||
|
||||
ShellExecutionService.completeVirtualExecution(backgroundPid, {
|
||||
exitCode: 0,
|
||||
});
|
||||
} catch (error: unknown) {
|
||||
const partialOutput = reassembler.toString();
|
||||
lastOutput = this.publishBackgroundDelta(
|
||||
backgroundPid,
|
||||
lastOutput,
|
||||
partialOutput,
|
||||
);
|
||||
const errorMessage = `Error calling remote agent: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`;
|
||||
ShellExecutionService.completeVirtualExecution(backgroundPid, {
|
||||
error: new Error(errorMessage),
|
||||
aborted: executionController.signal.aborted,
|
||||
exitCode: executionController.signal.aborted ? 130 : 1,
|
||||
});
|
||||
} finally {
|
||||
_signal.removeEventListener('abort', onAbort);
|
||||
// Persist state even on partial failures or aborts to maintain conversational continuity.
|
||||
RemoteAgentInvocation.sessionState.set(this.definition.name, {
|
||||
contextId: this.contextId,
|
||||
taskId: this.taskId,
|
||||
signal: _signal,
|
||||
},
|
||||
);
|
||||
|
||||
let finalResponse: SendMessageResult | undefined;
|
||||
|
||||
for await (const chunk of stream) {
|
||||
if (_signal.aborted) {
|
||||
throw new Error('Operation aborted');
|
||||
}
|
||||
finalResponse = chunk;
|
||||
reassembler.update(chunk);
|
||||
|
||||
if (updateOutput) {
|
||||
updateOutput(reassembler.toString());
|
||||
}
|
||||
|
||||
const {
|
||||
contextId: newContextId,
|
||||
taskId: newTaskId,
|
||||
clearTaskId,
|
||||
} = extractIdsFromResponse(chunk);
|
||||
|
||||
if (newContextId) {
|
||||
this.contextId = newContextId;
|
||||
}
|
||||
|
||||
this.taskId = clearTaskId ? undefined : (newTaskId ?? this.taskId);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if (!finalResponse) {
|
||||
throw new Error('No response from remote agent.');
|
||||
}
|
||||
|
||||
const finalOutput = reassembler.toString();
|
||||
|
||||
debugLogger.debug(
|
||||
`[RemoteAgent] Final response from ${this.definition.name}:\n${JSON.stringify(finalResponse, null, 2)}`,
|
||||
);
|
||||
void run();
|
||||
const executionResult = await result;
|
||||
|
||||
if (executionResult.backgrounded) {
|
||||
const command = `${this.getDescription()}: ${this.params.query}`;
|
||||
const backgroundMessage = `Remote agent moved to background (PID: ${backgroundPid}). Output hidden. Press Ctrl+B to view.`;
|
||||
return {
|
||||
llmContent: [{ text: finalOutput }],
|
||||
returnDisplay: finalOutput,
|
||||
llmContent: [{ text: backgroundMessage }],
|
||||
returnDisplay: backgroundMessage,
|
||||
data: {
|
||||
pid: backgroundPid,
|
||||
command,
|
||||
initialOutput: executionResult.output,
|
||||
},
|
||||
};
|
||||
} catch (error: unknown) {
|
||||
const partialOutput = reassembler.toString();
|
||||
const errorMessage = `Error calling remote agent: ${error instanceof Error ? error.message : String(error)}`;
|
||||
const fullDisplay = partialOutput
|
||||
? `${partialOutput}\n\n${errorMessage}`
|
||||
: errorMessage;
|
||||
}
|
||||
|
||||
if (executionResult.error) {
|
||||
const fullDisplay = executionResult.output
|
||||
? `${executionResult.output}\n\n${executionResult.error.message}`
|
||||
: executionResult.error.message;
|
||||
return {
|
||||
llmContent: [{ text: fullDisplay }],
|
||||
returnDisplay: fullDisplay,
|
||||
error: { message: errorMessage },
|
||||
error: { message: executionResult.error.message },
|
||||
};
|
||||
} finally {
|
||||
// Persist state even on partial failures or aborts to maintain conversational continuity.
|
||||
RemoteAgentInvocation.sessionState.set(this.definition.name, {
|
||||
contextId: this.contextId,
|
||||
taskId: this.taskId,
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
llmContent: [{ text: executionResult.output }],
|
||||
returnDisplay: executionResult.output,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
BaseToolInvocation,
|
||||
type ToolResult,
|
||||
type AnyDeclarativeTool,
|
||||
type ToolLiveOutput,
|
||||
} from '../tools/tools.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import type { HookSystem } from '../hooks/hookSystem.js';
|
||||
@@ -37,6 +38,30 @@ class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> {
|
||||
}
|
||||
}
|
||||
|
||||
class MockPidInvocation extends BaseToolInvocation<
|
||||
{ key?: string },
|
||||
ToolResult
|
||||
> {
|
||||
constructor(params: { key?: string }, messageBus: MessageBus) {
|
||||
super(params, messageBus);
|
||||
}
|
||||
getDescription() {
|
||||
return 'mock-pid';
|
||||
}
|
||||
async execute(
|
||||
_signal: AbortSignal,
|
||||
_updateOutput?: (output: ToolLiveOutput) => void,
|
||||
_shellExecutionConfig?: unknown,
|
||||
setPidCallback?: (pid: number) => void,
|
||||
) {
|
||||
setPidCallback?.(4242);
|
||||
return {
|
||||
llmContent: 'pid',
|
||||
returnDisplay: 'pid',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
describe('executeToolWithHooks', () => {
|
||||
let messageBus: MessageBus;
|
||||
let mockTool: AnyDeclarativeTool;
|
||||
@@ -258,4 +283,26 @@ describe('executeToolWithHooks', () => {
|
||||
expect(invocation.params.key).toBe('original');
|
||||
expect(mockTool.build).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should pass pid callback through for non-shell invocations', async () => {
|
||||
const invocation = new MockPidInvocation({}, messageBus);
|
||||
const abortSignal = new AbortController().signal;
|
||||
const setPidCallback = vi.fn();
|
||||
|
||||
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue(undefined);
|
||||
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue(undefined);
|
||||
|
||||
await executeToolWithHooks(
|
||||
invocation,
|
||||
'test_tool',
|
||||
abortSignal,
|
||||
mockTool,
|
||||
undefined,
|
||||
undefined,
|
||||
setPidCallback,
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(setPidCallback).toHaveBeenCalledWith(4242);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -15,7 +15,7 @@ import type {
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { ShellExecutionConfig } from '../index.js';
|
||||
import { ShellToolInvocation } from '../tools/shell.js';
|
||||
import type { ShellToolInvocation } from '../tools/shell.js';
|
||||
import { DiscoveredMCPToolInvocation } from '../tools/mcp-tool.js';
|
||||
|
||||
/**
|
||||
@@ -154,22 +154,23 @@ export async function executeToolWithHooks(
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the actual tool
|
||||
let toolResult: ToolResult;
|
||||
if (setPidCallback && invocation instanceof ShellToolInvocation) {
|
||||
toolResult = await invocation.execute(
|
||||
signal,
|
||||
liveOutputCallback,
|
||||
shellExecutionConfig,
|
||||
setPidCallback,
|
||||
);
|
||||
} else {
|
||||
toolResult = await invocation.execute(
|
||||
signal,
|
||||
liveOutputCallback,
|
||||
shellExecutionConfig,
|
||||
);
|
||||
}
|
||||
// Execute the actual tool. Some tools (not just shell) can optionally expose
|
||||
// a PID-like handle via a fourth parameter.
|
||||
const invocationWithPidSupport = invocation as AnyToolInvocation & {
|
||||
execute(
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (outputChunk: ToolLiveOutput) => void,
|
||||
shellExecutionConfig?: ShellExecutionConfig,
|
||||
setPidCallback?: (pid: number) => void,
|
||||
): Promise<ToolResult>;
|
||||
};
|
||||
|
||||
const toolResult: ToolResult = await invocationWithPidSupport.execute(
|
||||
signal,
|
||||
liveOutputCallback,
|
||||
shellExecutionConfig,
|
||||
setPidCallback,
|
||||
);
|
||||
|
||||
// Append notification if parameters were modified
|
||||
if (inputWasModified) {
|
||||
|
||||
@@ -534,6 +534,51 @@ describe('ToolExecutor', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should report PID updates for non-shell tools that support backgrounding', async () => {
|
||||
const mockTool = new MockTool({
|
||||
name: 'remote_agent_call',
|
||||
description: 'Remote agent call',
|
||||
});
|
||||
const invocation = mockTool.build({});
|
||||
|
||||
const testPid = 67890;
|
||||
vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockImplementation(
|
||||
async (_inv, _name, _sig, _tool, _liveCb, _shellCfg, setPidCallback) => {
|
||||
setPidCallback?.(testPid);
|
||||
return { llmContent: 'done', returnDisplay: 'done' };
|
||||
},
|
||||
);
|
||||
|
||||
const scheduledCall: ScheduledToolCall = {
|
||||
status: CoreToolCallStatus.Scheduled,
|
||||
request: {
|
||||
callId: 'call-remote-pid',
|
||||
name: 'remote_agent_call',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-remote-pid',
|
||||
},
|
||||
tool: mockTool,
|
||||
invocation: invocation as unknown as AnyToolInvocation,
|
||||
startTime: Date.now(),
|
||||
};
|
||||
|
||||
const onUpdateToolCall = vi.fn();
|
||||
|
||||
await executor.execute({
|
||||
call: scheduledCall,
|
||||
signal: new AbortController().signal,
|
||||
onUpdateToolCall,
|
||||
});
|
||||
|
||||
expect(onUpdateToolCall).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
status: CoreToolCallStatus.Executing,
|
||||
pid: testPid,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should return cancelled result with partial output when signal is aborted', async () => {
|
||||
const mockTool = new MockTool({
|
||||
name: 'slowTool',
|
||||
|
||||
@@ -16,7 +16,6 @@ import {
|
||||
type ToolLiveOutput,
|
||||
} from '../index.js';
|
||||
import { SHELL_TOOL_NAME } from '../tools/tool-names.js';
|
||||
import { ShellToolInvocation } from '../tools/shell.js';
|
||||
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
||||
import { executeToolWithHooks } from '../core/coreToolHookTriggers.js';
|
||||
import {
|
||||
@@ -89,43 +88,28 @@ export class ToolExecutor {
|
||||
let completedToolCall: CompletedToolCall;
|
||||
|
||||
try {
|
||||
let promise: Promise<ToolResult>;
|
||||
if (invocation instanceof ShellToolInvocation) {
|
||||
const setPidCallback = (pid: number) => {
|
||||
const executingCall: ExecutingToolCall = {
|
||||
...call,
|
||||
status: CoreToolCallStatus.Executing,
|
||||
tool,
|
||||
invocation,
|
||||
pid,
|
||||
startTime: 'startTime' in call ? call.startTime : undefined,
|
||||
};
|
||||
onUpdateToolCall(executingCall);
|
||||
const setPidCallback = (pid: number) => {
|
||||
const executingCall: ExecutingToolCall = {
|
||||
...call,
|
||||
status: CoreToolCallStatus.Executing,
|
||||
tool,
|
||||
invocation,
|
||||
pid,
|
||||
startTime: 'startTime' in call ? call.startTime : undefined,
|
||||
};
|
||||
promise = executeToolWithHooks(
|
||||
invocation,
|
||||
toolName,
|
||||
signal,
|
||||
tool,
|
||||
liveOutputCallback,
|
||||
shellExecutionConfig,
|
||||
setPidCallback,
|
||||
this.config,
|
||||
request.originalRequestName,
|
||||
);
|
||||
} else {
|
||||
promise = executeToolWithHooks(
|
||||
invocation,
|
||||
toolName,
|
||||
signal,
|
||||
tool,
|
||||
liveOutputCallback,
|
||||
shellExecutionConfig,
|
||||
undefined,
|
||||
this.config,
|
||||
request.originalRequestName,
|
||||
);
|
||||
}
|
||||
onUpdateToolCall(executingCall);
|
||||
};
|
||||
const promise = executeToolWithHooks(
|
||||
invocation,
|
||||
toolName,
|
||||
signal,
|
||||
tool,
|
||||
liveOutputCallback,
|
||||
shellExecutionConfig,
|
||||
setPidCallback,
|
||||
this.config,
|
||||
request.originalRequestName,
|
||||
);
|
||||
|
||||
const toolResult: ToolResult = await promise;
|
||||
|
||||
|
||||
@@ -152,6 +152,11 @@ interface ActiveChildProcess {
|
||||
};
|
||||
}
|
||||
|
||||
interface ActiveVirtualProcess {
|
||||
output: string;
|
||||
onKill?: () => void;
|
||||
}
|
||||
|
||||
const getFullBufferText = (terminal: pkg.Terminal): string => {
|
||||
const buffer = terminal.buffer.active;
|
||||
const lines: string[] = [];
|
||||
@@ -198,6 +203,10 @@ const getFullBufferText = (terminal: pkg.Terminal): string => {
|
||||
export class ShellExecutionService {
|
||||
private static activePtys = new Map<number, ActivePty>();
|
||||
private static activeChildProcesses = new Map<number, ActiveChildProcess>();
|
||||
private static activeVirtualProcesses = new Map<
|
||||
number,
|
||||
ActiveVirtualProcess
|
||||
>();
|
||||
private static exitedPtyInfo = new Map<
|
||||
number,
|
||||
{ exitCode: number; signal?: number }
|
||||
@@ -210,6 +219,7 @@ export class ShellExecutionService {
|
||||
number,
|
||||
Set<(event: ShellOutputEvent) => void>
|
||||
>();
|
||||
private static nextVirtualPid = 2_000_000_000;
|
||||
/**
|
||||
* Executes a shell command using `node-pty`, capturing all output and lifecycle events.
|
||||
*
|
||||
@@ -292,6 +302,100 @@ export class ShellExecutionService {
|
||||
}
|
||||
}
|
||||
|
||||
private static allocateVirtualPid(): number {
|
||||
let pid = ++this.nextVirtualPid;
|
||||
while (
|
||||
this.activePtys.has(pid) ||
|
||||
this.activeChildProcesses.has(pid) ||
|
||||
this.activeVirtualProcesses.has(pid)
|
||||
) {
|
||||
pid = ++this.nextVirtualPid;
|
||||
}
|
||||
return pid;
|
||||
}
|
||||
|
||||
static createVirtualExecution(
|
||||
initialOutput = '',
|
||||
onKill?: () => void,
|
||||
): ShellExecutionHandle {
|
||||
const pid = this.allocateVirtualPid();
|
||||
this.activeVirtualProcesses.set(pid, {
|
||||
output: initialOutput,
|
||||
onKill,
|
||||
});
|
||||
|
||||
const result = new Promise<ShellExecutionResult>((resolve) => {
|
||||
this.activeResolvers.set(pid, resolve);
|
||||
});
|
||||
|
||||
return { pid, result };
|
||||
}
|
||||
|
||||
static appendVirtualOutput(pid: number, chunk: string): void {
|
||||
const virtual = this.activeVirtualProcesses.get(pid);
|
||||
if (!virtual || chunk.length === 0) {
|
||||
return;
|
||||
}
|
||||
virtual.output += chunk;
|
||||
this.emitEvent(pid, { type: 'data', chunk });
|
||||
}
|
||||
|
||||
static completeVirtualExecution(
|
||||
pid: number,
|
||||
options?: {
|
||||
exitCode?: number | null;
|
||||
signal?: number | null;
|
||||
error?: Error | null;
|
||||
aborted?: boolean;
|
||||
},
|
||||
): void {
|
||||
const virtual = this.activeVirtualProcesses.get(pid);
|
||||
if (!virtual) {
|
||||
return;
|
||||
}
|
||||
|
||||
const {
|
||||
error = null,
|
||||
aborted = false,
|
||||
exitCode = error ? 1 : 0,
|
||||
signal = null,
|
||||
} = options ?? {};
|
||||
|
||||
const resolve = this.activeResolvers.get(pid);
|
||||
if (resolve) {
|
||||
resolve({
|
||||
rawOutput: Buffer.from(virtual.output, 'utf8'),
|
||||
output: virtual.output,
|
||||
exitCode,
|
||||
signal,
|
||||
error,
|
||||
aborted,
|
||||
pid,
|
||||
executionMethod: 'none',
|
||||
});
|
||||
this.activeResolvers.delete(pid);
|
||||
}
|
||||
|
||||
this.emitEvent(pid, {
|
||||
type: 'exit',
|
||||
exitCode,
|
||||
signal,
|
||||
});
|
||||
this.activeListeners.delete(pid);
|
||||
this.activeVirtualProcesses.delete(pid);
|
||||
|
||||
this.exitedPtyInfo.set(pid, {
|
||||
exitCode: exitCode ?? 0,
|
||||
signal: signal ?? undefined,
|
||||
});
|
||||
setTimeout(
|
||||
() => {
|
||||
this.exitedPtyInfo.delete(pid);
|
||||
},
|
||||
5 * 60 * 1000,
|
||||
).unref();
|
||||
}
|
||||
|
||||
private static childProcessFallback(
|
||||
commandToExecute: string,
|
||||
cwd: string,
|
||||
@@ -933,6 +1037,10 @@ export class ShellExecutionService {
|
||||
}
|
||||
|
||||
static isPtyActive(pid: number): boolean {
|
||||
if (this.activeVirtualProcesses.has(pid)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (this.activeChildProcesses.has(pid)) {
|
||||
try {
|
||||
return process.kill(pid, 0);
|
||||
@@ -984,6 +1092,15 @@ export class ShellExecutionService {
|
||||
return () => {
|
||||
activeChild?.process.removeListener('exit', listener);
|
||||
};
|
||||
} else if (this.activeVirtualProcesses.has(pid)) {
|
||||
const listener = (event: ShellOutputEvent) => {
|
||||
if (event.type === 'exit') {
|
||||
callback(event.exitCode ?? 0, event.signal ?? undefined);
|
||||
unsubscribe();
|
||||
}
|
||||
};
|
||||
const unsubscribe = this.subscribe(pid, listener);
|
||||
return unsubscribe;
|
||||
} else {
|
||||
// Check if it already exited recently
|
||||
const exitedInfo = this.exitedPtyInfo.get(pid);
|
||||
@@ -1002,8 +1119,17 @@ export class ShellExecutionService {
|
||||
static kill(pid: number): void {
|
||||
const activePty = this.activePtys.get(pid);
|
||||
const activeChild = this.activeChildProcesses.get(pid);
|
||||
const activeVirtual = this.activeVirtualProcesses.get(pid);
|
||||
|
||||
if (activeChild) {
|
||||
if (activeVirtual) {
|
||||
activeVirtual.onKill?.();
|
||||
this.completeVirtualExecution(pid, {
|
||||
error: new Error('Operation cancelled by user.'),
|
||||
aborted: true,
|
||||
exitCode: 130,
|
||||
});
|
||||
return;
|
||||
} else if (activeChild) {
|
||||
killProcessGroup({ pid }).catch(() => {});
|
||||
this.activeChildProcesses.delete(pid);
|
||||
} else if (activePty) {
|
||||
@@ -1029,6 +1155,7 @@ export class ShellExecutionService {
|
||||
|
||||
const activePty = this.activePtys.get(pid);
|
||||
const activeChild = this.activeChildProcesses.get(pid);
|
||||
const activeVirtual = this.activeVirtualProcesses.get(pid);
|
||||
|
||||
if (activePty) {
|
||||
output = getFullBufferText(activePty.headlessTerminal);
|
||||
@@ -1057,6 +1184,19 @@ export class ShellExecutionService {
|
||||
executionMethod: 'child_process',
|
||||
backgrounded: true,
|
||||
});
|
||||
} else if (activeVirtual) {
|
||||
output = activeVirtual.output;
|
||||
resolve({
|
||||
rawOutput,
|
||||
output,
|
||||
exitCode: null,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid,
|
||||
executionMethod: 'none',
|
||||
backgrounded: true,
|
||||
});
|
||||
}
|
||||
|
||||
this.activeResolvers.delete(pid);
|
||||
@@ -1075,6 +1215,7 @@ export class ShellExecutionService {
|
||||
// Send current buffer content immediately
|
||||
const activePty = this.activePtys.get(pid);
|
||||
const activeChild = this.activeChildProcesses.get(pid);
|
||||
const activeVirtual = this.activeVirtualProcesses.get(pid);
|
||||
|
||||
if (activePty) {
|
||||
// Use serializeTerminalToObject to preserve colors and structure
|
||||
@@ -1096,6 +1237,8 @@ export class ShellExecutionService {
|
||||
if (output) {
|
||||
listener({ type: 'data', chunk: output });
|
||||
}
|
||||
} else if (activeVirtual?.output) {
|
||||
listener({ type: 'data', chunk: activeVirtual.output });
|
||||
}
|
||||
|
||||
return () => {
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { ShellExecutionService } from './shellExecutionService.js';
|
||||
|
||||
describe('ShellExecutionService virtual executions', () => {
|
||||
it('completes a virtual execution in the foreground', async () => {
|
||||
const { pid, result } = ShellExecutionService.createVirtualExecution();
|
||||
const onExit = vi.fn();
|
||||
const unsubscribe = ShellExecutionService.onExit(pid!, onExit);
|
||||
|
||||
ShellExecutionService.appendVirtualOutput(pid!, 'Hello');
|
||||
ShellExecutionService.appendVirtualOutput(pid!, ' World');
|
||||
ShellExecutionService.completeVirtualExecution(pid!, { exitCode: 0 });
|
||||
|
||||
const executionResult = await result;
|
||||
|
||||
expect(executionResult.output).toBe('Hello World');
|
||||
expect(executionResult.backgrounded).toBeUndefined();
|
||||
expect(executionResult.exitCode).toBe(0);
|
||||
expect(executionResult.error).toBeNull();
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(onExit).toHaveBeenCalledWith(0, undefined);
|
||||
});
|
||||
|
||||
unsubscribe();
|
||||
});
|
||||
|
||||
it('supports backgrounding virtual executions and streaming additional output', async () => {
|
||||
const { pid, result } = ShellExecutionService.createVirtualExecution();
|
||||
const chunks: string[] = [];
|
||||
const onExit = vi.fn();
|
||||
|
||||
const unsubscribeStream = ShellExecutionService.subscribe(pid!, (event) => {
|
||||
if (event.type === 'data' && typeof event.chunk === 'string') {
|
||||
chunks.push(event.chunk);
|
||||
}
|
||||
});
|
||||
const unsubscribeExit = ShellExecutionService.onExit(pid!, onExit);
|
||||
|
||||
ShellExecutionService.appendVirtualOutput(pid!, 'Chunk 1');
|
||||
ShellExecutionService.background(pid!);
|
||||
|
||||
const backgroundResult = await result;
|
||||
expect(backgroundResult.backgrounded).toBe(true);
|
||||
expect(backgroundResult.output).toBe('Chunk 1');
|
||||
|
||||
ShellExecutionService.appendVirtualOutput(pid!, '\nChunk 2');
|
||||
ShellExecutionService.completeVirtualExecution(pid!, { exitCode: 0 });
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(chunks.join('')).toContain('Chunk 2');
|
||||
expect(onExit).toHaveBeenCalledWith(0, undefined);
|
||||
});
|
||||
|
||||
unsubscribeStream();
|
||||
unsubscribeExit();
|
||||
});
|
||||
|
||||
it('kills virtual executions via the existing kill API', async () => {
|
||||
const onKill = vi.fn();
|
||||
const { pid, result } = ShellExecutionService.createVirtualExecution(
|
||||
'',
|
||||
onKill,
|
||||
);
|
||||
|
||||
ShellExecutionService.appendVirtualOutput(pid!, 'work');
|
||||
ShellExecutionService.kill(pid!);
|
||||
|
||||
const killResult = await result;
|
||||
expect(onKill).toHaveBeenCalledTimes(1);
|
||||
expect(killResult.aborted).toBe(true);
|
||||
expect(killResult.exitCode).toBe(130);
|
||||
expect(killResult.error?.message).toContain('Operation cancelled by user');
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user