feat(core): experimental in-progress steering hints (2 of 2) (#19307)

This commit is contained in:
joshualitt
2026-02-18 14:05:50 -08:00
committed by GitHub
parent 81c8893e05
commit 87f5dd15d6
37 changed files with 1280 additions and 48 deletions

View File

@@ -2037,6 +2037,215 @@ describe('LocalAgentExecutor', () => {
expect(recoveryEvent.success).toBe(true);
expect(recoveryEvent.reason).toBe(AgentTerminateMode.MAX_TURNS);
});
describe('Model Steering', () => {
let configWithHints: Config;
beforeEach(() => {
configWithHints = makeFakeConfig({ modelSteering: true });
vi.spyOn(configWithHints, 'getAgentRegistry').mockReturnValue({
getAllAgentNames: () => [],
} as unknown as AgentRegistry);
vi.spyOn(configWithHints, 'getToolRegistry').mockReturnValue(
parentToolRegistry,
);
});
it('should inject user hints into the next turn after they are added', async () => {
const definition = createTestDefinition();
const executor = await LocalAgentExecutor.create(
definition,
configWithHints,
);
// Turn 1: Model calls LS
mockModelResponse(
[{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }],
'T1: Listing',
);
// We use a manual promise to ensure the hint is added WHILE Turn 1 is "running"
let resolveToolCall: (value: unknown) => void;
const toolCallPromise = new Promise((resolve) => {
resolveToolCall = resolve;
});
mockScheduleAgentTools.mockReturnValueOnce(toolCallPromise);
// Turn 2: Model calls complete_task
mockModelResponse(
[
{
name: TASK_COMPLETE_TOOL_NAME,
args: { finalResult: 'Done' },
id: 'call2',
},
],
'T2: Done',
);
const runPromise = executor.run({ goal: 'Hint test' }, signal);
// Give the loop a chance to start and register the listener
await vi.advanceTimersByTimeAsync(1);
configWithHints.userHintService.addUserHint('Initial Hint');
// Resolve the tool call to complete Turn 1
resolveToolCall!([
{
status: 'success',
request: {
callId: 'call1',
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'p1',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'call1',
resultDisplay: 'file1.txt',
responseParts: [
{
functionResponse: {
name: LS_TOOL_NAME,
response: { result: 'file1.txt' },
id: 'call1',
},
},
],
},
},
]);
await runPromise;
// The first call to sendMessageStream should NOT contain the hint (it was added after start)
// The SECOND call to sendMessageStream SHOULD contain the hint
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
const secondTurnMessageParts = mockSendMessageStream.mock.calls[1][1];
expect(secondTurnMessageParts).toContainEqual(
expect.objectContaining({
text: expect.stringContaining('Initial Hint'),
}),
);
});
it('should NOT inject legacy hints added before executor was created', async () => {
const definition = createTestDefinition();
configWithHints.userHintService.addUserHint('Legacy Hint');
const executor = await LocalAgentExecutor.create(
definition,
configWithHints,
);
mockModelResponse([
{
name: TASK_COMPLETE_TOOL_NAME,
args: { finalResult: 'Done' },
id: 'call1',
},
]);
await executor.run({ goal: 'Isolation test' }, signal);
// The first call to sendMessageStream should NOT contain the legacy hint
expect(mockSendMessageStream).toHaveBeenCalled();
const firstTurnMessageParts = mockSendMessageStream.mock.calls[0][1];
// We expect only the goal, no hints injected at turn start
for (const part of firstTurnMessageParts) {
if (part.text) {
expect(part.text).not.toContain('Legacy Hint');
}
}
});
it('should inject mid-execution hints into subsequent turns', async () => {
const definition = createTestDefinition();
const executor = await LocalAgentExecutor.create(
definition,
configWithHints,
);
// Turn 1: Model calls LS
mockModelResponse(
[{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }],
'T1: Listing',
);
// We use a manual promise to ensure the hint is added WHILE Turn 1 is "running"
let resolveToolCall: (value: unknown) => void;
const toolCallPromise = new Promise((resolve) => {
resolveToolCall = resolve;
});
mockScheduleAgentTools.mockReturnValueOnce(toolCallPromise);
// Turn 2: Model calls complete_task
mockModelResponse(
[
{
name: TASK_COMPLETE_TOOL_NAME,
args: { finalResult: 'Done' },
id: 'call2',
},
],
'T2: Done',
);
// Start execution
const runPromise = executor.run({ goal: 'Mid-turn hint test' }, signal);
// Small delay to ensure the run loop has reached the await and registered listener
await vi.advanceTimersByTimeAsync(1);
// Add the hint while the tool call is pending
configWithHints.userHintService.addUserHint('Corrective Hint');
// Now resolve the tool call to complete Turn 1
resolveToolCall!([
{
status: 'success',
request: {
callId: 'call1',
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'p1',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'call1',
resultDisplay: 'file1.txt',
responseParts: [
{
functionResponse: {
name: LS_TOOL_NAME,
response: { result: 'file1.txt' },
id: 'call1',
},
},
],
},
},
]);
await runPromise;
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
// The second turn (turn 1) should contain the corrective hint.
const secondTurnMessageParts = mockSendMessageStream.mock.calls[1][1];
expect(secondTurnMessageParts).toContainEqual(
expect.objectContaining({
text: expect.stringContaining('Corrective Hint'),
}),
);
});
});
});
describe('Chat Compression', () => {
const mockWorkResponse = (id: string) => {

View File

@@ -60,6 +60,7 @@ import { getToolCallContext } from '../utils/toolCallContext.js';
import { scheduleAgentTools } from './agent-scheduler.js';
import { DeadlineTimer } from '../utils/deadlineTimer.js';
import { LlmRole } from '../telemetry/types.js';
import { formatUserHintsForModel } from '../utils/fastAckHelper.js';
/** A callback function to report on agent activity. */
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
@@ -463,45 +464,82 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
const query = this.definition.promptConfig.query
? templateString(this.definition.promptConfig.query, augmentedInputs)
: DEFAULT_QUERY_STRING;
let currentMessage: Content = { role: 'user', parts: [{ text: query }] };
while (true) {
// Check for termination conditions like max turns.
const reason = this.checkTermination(turnCounter, maxTurns);
if (reason) {
terminateReason = reason;
break;
}
const pendingHintsQueue: string[] = [];
const hintListener = (hint: string) => {
pendingHintsQueue.push(hint);
};
// Capture the index of the last hint before starting to avoid re-injecting old hints.
// NOTE: Hints added AFTER this point will be broadcast to all currently running
// local agents via the listener below.
const startIndex =
this.runtimeContext.userHintService.getLatestHintIndex();
this.runtimeContext.userHintService.onUserHint(hintListener);
// Check for timeout or external abort.
if (combinedSignal.aborted) {
// Determine which signal caused the abort.
terminateReason = deadlineTimer.signal.aborted
? AgentTerminateMode.TIMEOUT
: AgentTerminateMode.ABORTED;
break;
}
try {
const initialHints =
this.runtimeContext.userHintService.getUserHintsAfter(startIndex);
const formattedInitialHints = formatUserHintsForModel(initialHints);
const turnResult = await this.executeTurn(
chat,
currentMessage,
turnCounter++,
combinedSignal,
deadlineTimer.signal,
onWaitingForConfirmation,
);
let currentMessage: Content = formattedInitialHints
? {
role: 'user',
parts: [{ text: formattedInitialHints }, { text: query }],
}
: { role: 'user', parts: [{ text: query }] };
if (turnResult.status === 'stop') {
terminateReason = turnResult.terminateReason;
// Only set finalResult if the turn provided one (e.g., error or goal).
if (turnResult.finalResult) {
finalResult = turnResult.finalResult;
while (true) {
// Check for termination conditions like max turns.
const reason = this.checkTermination(turnCounter, maxTurns);
if (reason) {
terminateReason = reason;
break;
}
break; // Exit the loop for *any* stop reason.
}
// If status is 'continue', update message for the next loop
currentMessage = turnResult.nextMessage;
// Check for timeout or external abort.
if (combinedSignal.aborted) {
// Determine which signal caused the abort.
terminateReason = deadlineTimer.signal.aborted
? AgentTerminateMode.TIMEOUT
: AgentTerminateMode.ABORTED;
break;
}
const turnResult = await this.executeTurn(
chat,
currentMessage,
turnCounter++,
combinedSignal,
deadlineTimer.signal,
onWaitingForConfirmation,
);
if (turnResult.status === 'stop') {
terminateReason = turnResult.terminateReason;
// Only set finalResult if the turn provided one (e.g., error or goal).
if (turnResult.finalResult) {
finalResult = turnResult.finalResult;
}
break; // Exit the loop for *any* stop reason.
}
// If status is 'continue', update message for the next loop
currentMessage = turnResult.nextMessage;
// Check for new user steering hints collected via subscription
if (pendingHintsQueue.length > 0) {
const hintsToProcess = [...pendingHintsQueue];
pendingHintsQueue.length = 0;
const formattedHints = formatUserHintsForModel(hintsToProcess);
if (formattedHints) {
// Append hints to the current message (next turn)
currentMessage.parts ??= [];
currentMessage.parts.unshift({ text: formattedHints });
}
}
}
} finally {
this.runtimeContext.userHintService.offUserHint(hintListener);
}
// === UNIFIED RECOVERY BLOCK ===

View File

@@ -148,4 +148,129 @@ describe('SubAgentInvocation', () => {
updateOutput,
);
});
describe('withUserHints', () => {
it('should NOT modify query for local agents', async () => {
mockConfig = makeFakeConfig({ modelSteering: true });
mockConfig.userHintService.addUserHint('Test Hint');
const tool = new SubagentTool(testDefinition, mockConfig, mockMessageBus);
const params = { query: 'original query' };
// @ts-expect-error - accessing private method for testing
const invocation = tool.createInvocation(params, mockMessageBus);
// @ts-expect-error - accessing private method for testing
const hintedParams = invocation.withUserHints(params);
expect(hintedParams.query).toBe('original query');
});
it('should NOT modify query for remote agents if model steering is disabled', async () => {
mockConfig = makeFakeConfig({ modelSteering: false });
mockConfig.userHintService.addUserHint('Test Hint');
const tool = new SubagentTool(
testRemoteDefinition,
mockConfig,
mockMessageBus,
);
const params = { query: 'original query' };
// @ts-expect-error - accessing private method for testing
const invocation = tool.createInvocation(params, mockMessageBus);
// @ts-expect-error - accessing private method for testing
const hintedParams = invocation.withUserHints(params);
expect(hintedParams.query).toBe('original query');
});
it('should NOT modify query for remote agents if there are no hints', async () => {
mockConfig = makeFakeConfig({ modelSteering: true });
const tool = new SubagentTool(
testRemoteDefinition,
mockConfig,
mockMessageBus,
);
const params = { query: 'original query' };
// @ts-expect-error - accessing private method for testing
const invocation = tool.createInvocation(params, mockMessageBus);
// @ts-expect-error - accessing private method for testing
const hintedParams = invocation.withUserHints(params);
expect(hintedParams.query).toBe('original query');
});
it('should prepend hints to query for remote agents when hints exist and steering is enabled', async () => {
mockConfig = makeFakeConfig({ modelSteering: true });
const tool = new SubagentTool(
testRemoteDefinition,
mockConfig,
mockMessageBus,
);
const params = { query: 'original query' };
// @ts-expect-error - accessing private method for testing
const invocation = tool.createInvocation(params, mockMessageBus);
mockConfig.userHintService.addUserHint('Hint 1');
mockConfig.userHintService.addUserHint('Hint 2');
// @ts-expect-error - accessing private method for testing
const hintedParams = invocation.withUserHints(params);
expect(hintedParams.query).toContain('Hint 1');
expect(hintedParams.query).toContain('Hint 2');
expect(hintedParams.query).toMatch(/original query$/);
});
it('should NOT include legacy hints added before the invocation was created', async () => {
mockConfig = makeFakeConfig({ modelSteering: true });
mockConfig.userHintService.addUserHint('Legacy Hint');
const tool = new SubagentTool(
testRemoteDefinition,
mockConfig,
mockMessageBus,
);
const params = { query: 'original query' };
// Creation of invocation captures the current hint state
// @ts-expect-error - accessing private method for testing
const invocation = tool.createInvocation(params, mockMessageBus);
// Verify no hints are present yet
// @ts-expect-error - accessing private method for testing
let hintedParams = invocation.withUserHints(params);
expect(hintedParams.query).toBe('original query');
// Add a new hint after creation
mockConfig.userHintService.addUserHint('New Hint');
// @ts-expect-error - accessing private method for testing
hintedParams = invocation.withUserHints(params);
expect(hintedParams.query).toContain('New Hint');
expect(hintedParams.query).not.toContain('Legacy Hint');
});
it('should NOT modify query if query is missing or not a string', async () => {
mockConfig = makeFakeConfig({ modelSteering: true });
mockConfig.userHintService.addUserHint('Hint');
const tool = new SubagentTool(
testRemoteDefinition,
mockConfig,
mockMessageBus,
);
const params = { other: 'param' };
// @ts-expect-error - accessing private method for testing
const invocation = tool.createInvocation(params, mockMessageBus);
// @ts-expect-error - accessing private method for testing
const hintedParams = invocation.withUserHints(params);
expect(hintedParams).toEqual(params);
});
});
});

View File

@@ -18,6 +18,7 @@ import type { MessageBus } from '../confirmation-bus/message-bus.js';
import type { AgentDefinition, AgentInputs } from './types.js';
import { SubagentToolWrapper } from './subagent-tool-wrapper.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { formatUserHintsForModel } from '../utils/fastAckHelper.js';
export class SubagentTool extends BaseDeclarativeTool<AgentInputs, ToolResult> {
constructor(
@@ -65,6 +66,8 @@ export class SubagentTool extends BaseDeclarativeTool<AgentInputs, ToolResult> {
}
class SubAgentInvocation extends BaseToolInvocation<AgentInputs, ToolResult> {
private readonly startIndex: number;
constructor(
params: AgentInputs,
private readonly definition: AgentDefinition,
@@ -79,6 +82,7 @@ class SubAgentInvocation extends BaseToolInvocation<AgentInputs, ToolResult> {
_toolName ?? definition.name,
_toolDisplayName ?? definition.displayName ?? definition.name,
);
this.startIndex = config.userHintService.getLatestHintIndex();
}
getDescription(): string {
@@ -88,7 +92,10 @@ class SubAgentInvocation extends BaseToolInvocation<AgentInputs, ToolResult> {
override async shouldConfirmExecute(
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
const invocation = this.buildSubInvocation(this.definition, this.params);
const invocation = this.buildSubInvocation(
this.definition,
this.withUserHints(this.params),
);
return invocation.shouldConfirmExecute(abortSignal);
}
@@ -107,11 +114,38 @@ class SubAgentInvocation extends BaseToolInvocation<AgentInputs, ToolResult> {
);
}
const invocation = this.buildSubInvocation(this.definition, this.params);
const invocation = this.buildSubInvocation(
this.definition,
this.withUserHints(this.params),
);
return invocation.execute(signal, updateOutput);
}
private withUserHints(agentArgs: AgentInputs): AgentInputs {
if (this.definition.kind !== 'remote') {
return agentArgs;
}
const userHints = this.config.userHintService.getUserHintsAfter(
this.startIndex,
);
const formattedHints = formatUserHintsForModel(userHints);
if (!formattedHints) {
return agentArgs;
}
const query = agentArgs['query'];
if (typeof query !== 'string' || query.trim().length === 0) {
return agentArgs;
}
return {
...agentArgs,
query: `${formattedHints}\n\n${query}`,
};
}
private buildSubInvocation(
definition: AgentDefinition,
agentArgs: AgentInputs,