mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-17 17:41:24 -07:00
feat(core): experimental in-progress steering hints (2 of 2) (#19307)
This commit is contained in:
@@ -2037,6 +2037,215 @@ describe('LocalAgentExecutor', () => {
|
||||
expect(recoveryEvent.success).toBe(true);
|
||||
expect(recoveryEvent.reason).toBe(AgentTerminateMode.MAX_TURNS);
|
||||
});
|
||||
|
||||
describe('Model Steering', () => {
|
||||
let configWithHints: Config;
|
||||
|
||||
beforeEach(() => {
|
||||
configWithHints = makeFakeConfig({ modelSteering: true });
|
||||
vi.spyOn(configWithHints, 'getAgentRegistry').mockReturnValue({
|
||||
getAllAgentNames: () => [],
|
||||
} as unknown as AgentRegistry);
|
||||
vi.spyOn(configWithHints, 'getToolRegistry').mockReturnValue(
|
||||
parentToolRegistry,
|
||||
);
|
||||
});
|
||||
|
||||
it('should inject user hints into the next turn after they are added', async () => {
|
||||
const definition = createTestDefinition();
|
||||
|
||||
const executor = await LocalAgentExecutor.create(
|
||||
definition,
|
||||
configWithHints,
|
||||
);
|
||||
|
||||
// Turn 1: Model calls LS
|
||||
mockModelResponse(
|
||||
[{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }],
|
||||
'T1: Listing',
|
||||
);
|
||||
|
||||
// We use a manual promise to ensure the hint is added WHILE Turn 1 is "running"
|
||||
let resolveToolCall: (value: unknown) => void;
|
||||
const toolCallPromise = new Promise((resolve) => {
|
||||
resolveToolCall = resolve;
|
||||
});
|
||||
mockScheduleAgentTools.mockReturnValueOnce(toolCallPromise);
|
||||
|
||||
// Turn 2: Model calls complete_task
|
||||
mockModelResponse(
|
||||
[
|
||||
{
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
args: { finalResult: 'Done' },
|
||||
id: 'call2',
|
||||
},
|
||||
],
|
||||
'T2: Done',
|
||||
);
|
||||
|
||||
const runPromise = executor.run({ goal: 'Hint test' }, signal);
|
||||
|
||||
// Give the loop a chance to start and register the listener
|
||||
await vi.advanceTimersByTimeAsync(1);
|
||||
|
||||
configWithHints.userHintService.addUserHint('Initial Hint');
|
||||
|
||||
// Resolve the tool call to complete Turn 1
|
||||
resolveToolCall!([
|
||||
{
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: 'call1',
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'p1',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: 'call1',
|
||||
resultDisplay: 'file1.txt',
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: LS_TOOL_NAME,
|
||||
response: { result: 'file1.txt' },
|
||||
id: 'call1',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
await runPromise;
|
||||
|
||||
// The first call to sendMessageStream should NOT contain the hint (it was added after start)
|
||||
// The SECOND call to sendMessageStream SHOULD contain the hint
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
const secondTurnMessageParts = mockSendMessageStream.mock.calls[1][1];
|
||||
expect(secondTurnMessageParts).toContainEqual(
|
||||
expect.objectContaining({
|
||||
text: expect.stringContaining('Initial Hint'),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should NOT inject legacy hints added before executor was created', async () => {
|
||||
const definition = createTestDefinition();
|
||||
configWithHints.userHintService.addUserHint('Legacy Hint');
|
||||
|
||||
const executor = await LocalAgentExecutor.create(
|
||||
definition,
|
||||
configWithHints,
|
||||
);
|
||||
|
||||
mockModelResponse([
|
||||
{
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
args: { finalResult: 'Done' },
|
||||
id: 'call1',
|
||||
},
|
||||
]);
|
||||
|
||||
await executor.run({ goal: 'Isolation test' }, signal);
|
||||
|
||||
// The first call to sendMessageStream should NOT contain the legacy hint
|
||||
expect(mockSendMessageStream).toHaveBeenCalled();
|
||||
const firstTurnMessageParts = mockSendMessageStream.mock.calls[0][1];
|
||||
// We expect only the goal, no hints injected at turn start
|
||||
for (const part of firstTurnMessageParts) {
|
||||
if (part.text) {
|
||||
expect(part.text).not.toContain('Legacy Hint');
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
it('should inject mid-execution hints into subsequent turns', async () => {
|
||||
const definition = createTestDefinition();
|
||||
const executor = await LocalAgentExecutor.create(
|
||||
definition,
|
||||
configWithHints,
|
||||
);
|
||||
|
||||
// Turn 1: Model calls LS
|
||||
mockModelResponse(
|
||||
[{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }],
|
||||
'T1: Listing',
|
||||
);
|
||||
|
||||
// We use a manual promise to ensure the hint is added WHILE Turn 1 is "running"
|
||||
let resolveToolCall: (value: unknown) => void;
|
||||
const toolCallPromise = new Promise((resolve) => {
|
||||
resolveToolCall = resolve;
|
||||
});
|
||||
mockScheduleAgentTools.mockReturnValueOnce(toolCallPromise);
|
||||
|
||||
// Turn 2: Model calls complete_task
|
||||
mockModelResponse(
|
||||
[
|
||||
{
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
args: { finalResult: 'Done' },
|
||||
id: 'call2',
|
||||
},
|
||||
],
|
||||
'T2: Done',
|
||||
);
|
||||
|
||||
// Start execution
|
||||
const runPromise = executor.run({ goal: 'Mid-turn hint test' }, signal);
|
||||
|
||||
// Small delay to ensure the run loop has reached the await and registered listener
|
||||
await vi.advanceTimersByTimeAsync(1);
|
||||
|
||||
// Add the hint while the tool call is pending
|
||||
configWithHints.userHintService.addUserHint('Corrective Hint');
|
||||
|
||||
// Now resolve the tool call to complete Turn 1
|
||||
resolveToolCall!([
|
||||
{
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: 'call1',
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'p1',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: 'call1',
|
||||
resultDisplay: 'file1.txt',
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: LS_TOOL_NAME,
|
||||
response: { result: 'file1.txt' },
|
||||
id: 'call1',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
await runPromise;
|
||||
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
|
||||
// The second turn (turn 1) should contain the corrective hint.
|
||||
const secondTurnMessageParts = mockSendMessageStream.mock.calls[1][1];
|
||||
expect(secondTurnMessageParts).toContainEqual(
|
||||
expect.objectContaining({
|
||||
text: expect.stringContaining('Corrective Hint'),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
describe('Chat Compression', () => {
|
||||
const mockWorkResponse = (id: string) => {
|
||||
|
||||
@@ -60,6 +60,7 @@ import { getToolCallContext } from '../utils/toolCallContext.js';
|
||||
import { scheduleAgentTools } from './agent-scheduler.js';
|
||||
import { DeadlineTimer } from '../utils/deadlineTimer.js';
|
||||
import { LlmRole } from '../telemetry/types.js';
|
||||
import { formatUserHintsForModel } from '../utils/fastAckHelper.js';
|
||||
|
||||
/** A callback function to report on agent activity. */
|
||||
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
|
||||
@@ -463,45 +464,82 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
const query = this.definition.promptConfig.query
|
||||
? templateString(this.definition.promptConfig.query, augmentedInputs)
|
||||
: DEFAULT_QUERY_STRING;
|
||||
let currentMessage: Content = { role: 'user', parts: [{ text: query }] };
|
||||
|
||||
while (true) {
|
||||
// Check for termination conditions like max turns.
|
||||
const reason = this.checkTermination(turnCounter, maxTurns);
|
||||
if (reason) {
|
||||
terminateReason = reason;
|
||||
break;
|
||||
}
|
||||
const pendingHintsQueue: string[] = [];
|
||||
const hintListener = (hint: string) => {
|
||||
pendingHintsQueue.push(hint);
|
||||
};
|
||||
// Capture the index of the last hint before starting to avoid re-injecting old hints.
|
||||
// NOTE: Hints added AFTER this point will be broadcast to all currently running
|
||||
// local agents via the listener below.
|
||||
const startIndex =
|
||||
this.runtimeContext.userHintService.getLatestHintIndex();
|
||||
this.runtimeContext.userHintService.onUserHint(hintListener);
|
||||
|
||||
// Check for timeout or external abort.
|
||||
if (combinedSignal.aborted) {
|
||||
// Determine which signal caused the abort.
|
||||
terminateReason = deadlineTimer.signal.aborted
|
||||
? AgentTerminateMode.TIMEOUT
|
||||
: AgentTerminateMode.ABORTED;
|
||||
break;
|
||||
}
|
||||
try {
|
||||
const initialHints =
|
||||
this.runtimeContext.userHintService.getUserHintsAfter(startIndex);
|
||||
const formattedInitialHints = formatUserHintsForModel(initialHints);
|
||||
|
||||
const turnResult = await this.executeTurn(
|
||||
chat,
|
||||
currentMessage,
|
||||
turnCounter++,
|
||||
combinedSignal,
|
||||
deadlineTimer.signal,
|
||||
onWaitingForConfirmation,
|
||||
);
|
||||
let currentMessage: Content = formattedInitialHints
|
||||
? {
|
||||
role: 'user',
|
||||
parts: [{ text: formattedInitialHints }, { text: query }],
|
||||
}
|
||||
: { role: 'user', parts: [{ text: query }] };
|
||||
|
||||
if (turnResult.status === 'stop') {
|
||||
terminateReason = turnResult.terminateReason;
|
||||
// Only set finalResult if the turn provided one (e.g., error or goal).
|
||||
if (turnResult.finalResult) {
|
||||
finalResult = turnResult.finalResult;
|
||||
while (true) {
|
||||
// Check for termination conditions like max turns.
|
||||
const reason = this.checkTermination(turnCounter, maxTurns);
|
||||
if (reason) {
|
||||
terminateReason = reason;
|
||||
break;
|
||||
}
|
||||
break; // Exit the loop for *any* stop reason.
|
||||
}
|
||||
|
||||
// If status is 'continue', update message for the next loop
|
||||
currentMessage = turnResult.nextMessage;
|
||||
// Check for timeout or external abort.
|
||||
if (combinedSignal.aborted) {
|
||||
// Determine which signal caused the abort.
|
||||
terminateReason = deadlineTimer.signal.aborted
|
||||
? AgentTerminateMode.TIMEOUT
|
||||
: AgentTerminateMode.ABORTED;
|
||||
break;
|
||||
}
|
||||
|
||||
const turnResult = await this.executeTurn(
|
||||
chat,
|
||||
currentMessage,
|
||||
turnCounter++,
|
||||
combinedSignal,
|
||||
deadlineTimer.signal,
|
||||
onWaitingForConfirmation,
|
||||
);
|
||||
|
||||
if (turnResult.status === 'stop') {
|
||||
terminateReason = turnResult.terminateReason;
|
||||
// Only set finalResult if the turn provided one (e.g., error or goal).
|
||||
if (turnResult.finalResult) {
|
||||
finalResult = turnResult.finalResult;
|
||||
}
|
||||
break; // Exit the loop for *any* stop reason.
|
||||
}
|
||||
|
||||
// If status is 'continue', update message for the next loop
|
||||
currentMessage = turnResult.nextMessage;
|
||||
|
||||
// Check for new user steering hints collected via subscription
|
||||
if (pendingHintsQueue.length > 0) {
|
||||
const hintsToProcess = [...pendingHintsQueue];
|
||||
pendingHintsQueue.length = 0;
|
||||
const formattedHints = formatUserHintsForModel(hintsToProcess);
|
||||
if (formattedHints) {
|
||||
// Append hints to the current message (next turn)
|
||||
currentMessage.parts ??= [];
|
||||
currentMessage.parts.unshift({ text: formattedHints });
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
this.runtimeContext.userHintService.offUserHint(hintListener);
|
||||
}
|
||||
|
||||
// === UNIFIED RECOVERY BLOCK ===
|
||||
|
||||
@@ -148,4 +148,129 @@ describe('SubAgentInvocation', () => {
|
||||
updateOutput,
|
||||
);
|
||||
});
|
||||
|
||||
describe('withUserHints', () => {
|
||||
it('should NOT modify query for local agents', async () => {
|
||||
mockConfig = makeFakeConfig({ modelSteering: true });
|
||||
mockConfig.userHintService.addUserHint('Test Hint');
|
||||
|
||||
const tool = new SubagentTool(testDefinition, mockConfig, mockMessageBus);
|
||||
const params = { query: 'original query' };
|
||||
// @ts-expect-error - accessing private method for testing
|
||||
const invocation = tool.createInvocation(params, mockMessageBus);
|
||||
|
||||
// @ts-expect-error - accessing private method for testing
|
||||
const hintedParams = invocation.withUserHints(params);
|
||||
|
||||
expect(hintedParams.query).toBe('original query');
|
||||
});
|
||||
|
||||
it('should NOT modify query for remote agents if model steering is disabled', async () => {
|
||||
mockConfig = makeFakeConfig({ modelSteering: false });
|
||||
mockConfig.userHintService.addUserHint('Test Hint');
|
||||
|
||||
const tool = new SubagentTool(
|
||||
testRemoteDefinition,
|
||||
mockConfig,
|
||||
mockMessageBus,
|
||||
);
|
||||
const params = { query: 'original query' };
|
||||
// @ts-expect-error - accessing private method for testing
|
||||
const invocation = tool.createInvocation(params, mockMessageBus);
|
||||
|
||||
// @ts-expect-error - accessing private method for testing
|
||||
const hintedParams = invocation.withUserHints(params);
|
||||
|
||||
expect(hintedParams.query).toBe('original query');
|
||||
});
|
||||
|
||||
it('should NOT modify query for remote agents if there are no hints', async () => {
|
||||
mockConfig = makeFakeConfig({ modelSteering: true });
|
||||
|
||||
const tool = new SubagentTool(
|
||||
testRemoteDefinition,
|
||||
mockConfig,
|
||||
mockMessageBus,
|
||||
);
|
||||
const params = { query: 'original query' };
|
||||
// @ts-expect-error - accessing private method for testing
|
||||
const invocation = tool.createInvocation(params, mockMessageBus);
|
||||
|
||||
// @ts-expect-error - accessing private method for testing
|
||||
const hintedParams = invocation.withUserHints(params);
|
||||
|
||||
expect(hintedParams.query).toBe('original query');
|
||||
});
|
||||
|
||||
it('should prepend hints to query for remote agents when hints exist and steering is enabled', async () => {
|
||||
mockConfig = makeFakeConfig({ modelSteering: true });
|
||||
|
||||
const tool = new SubagentTool(
|
||||
testRemoteDefinition,
|
||||
mockConfig,
|
||||
mockMessageBus,
|
||||
);
|
||||
const params = { query: 'original query' };
|
||||
// @ts-expect-error - accessing private method for testing
|
||||
const invocation = tool.createInvocation(params, mockMessageBus);
|
||||
|
||||
mockConfig.userHintService.addUserHint('Hint 1');
|
||||
mockConfig.userHintService.addUserHint('Hint 2');
|
||||
|
||||
// @ts-expect-error - accessing private method for testing
|
||||
const hintedParams = invocation.withUserHints(params);
|
||||
|
||||
expect(hintedParams.query).toContain('Hint 1');
|
||||
expect(hintedParams.query).toContain('Hint 2');
|
||||
expect(hintedParams.query).toMatch(/original query$/);
|
||||
});
|
||||
|
||||
it('should NOT include legacy hints added before the invocation was created', async () => {
|
||||
mockConfig = makeFakeConfig({ modelSteering: true });
|
||||
mockConfig.userHintService.addUserHint('Legacy Hint');
|
||||
|
||||
const tool = new SubagentTool(
|
||||
testRemoteDefinition,
|
||||
mockConfig,
|
||||
mockMessageBus,
|
||||
);
|
||||
const params = { query: 'original query' };
|
||||
|
||||
// Creation of invocation captures the current hint state
|
||||
// @ts-expect-error - accessing private method for testing
|
||||
const invocation = tool.createInvocation(params, mockMessageBus);
|
||||
|
||||
// Verify no hints are present yet
|
||||
// @ts-expect-error - accessing private method for testing
|
||||
let hintedParams = invocation.withUserHints(params);
|
||||
expect(hintedParams.query).toBe('original query');
|
||||
|
||||
// Add a new hint after creation
|
||||
mockConfig.userHintService.addUserHint('New Hint');
|
||||
// @ts-expect-error - accessing private method for testing
|
||||
hintedParams = invocation.withUserHints(params);
|
||||
|
||||
expect(hintedParams.query).toContain('New Hint');
|
||||
expect(hintedParams.query).not.toContain('Legacy Hint');
|
||||
});
|
||||
|
||||
it('should NOT modify query if query is missing or not a string', async () => {
|
||||
mockConfig = makeFakeConfig({ modelSteering: true });
|
||||
mockConfig.userHintService.addUserHint('Hint');
|
||||
|
||||
const tool = new SubagentTool(
|
||||
testRemoteDefinition,
|
||||
mockConfig,
|
||||
mockMessageBus,
|
||||
);
|
||||
const params = { other: 'param' };
|
||||
// @ts-expect-error - accessing private method for testing
|
||||
const invocation = tool.createInvocation(params, mockMessageBus);
|
||||
|
||||
// @ts-expect-error - accessing private method for testing
|
||||
const hintedParams = invocation.withUserHints(params);
|
||||
|
||||
expect(hintedParams).toEqual(params);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -18,6 +18,7 @@ import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import type { AgentDefinition, AgentInputs } from './types.js';
|
||||
import { SubagentToolWrapper } from './subagent-tool-wrapper.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import { formatUserHintsForModel } from '../utils/fastAckHelper.js';
|
||||
|
||||
export class SubagentTool extends BaseDeclarativeTool<AgentInputs, ToolResult> {
|
||||
constructor(
|
||||
@@ -65,6 +66,8 @@ export class SubagentTool extends BaseDeclarativeTool<AgentInputs, ToolResult> {
|
||||
}
|
||||
|
||||
class SubAgentInvocation extends BaseToolInvocation<AgentInputs, ToolResult> {
|
||||
private readonly startIndex: number;
|
||||
|
||||
constructor(
|
||||
params: AgentInputs,
|
||||
private readonly definition: AgentDefinition,
|
||||
@@ -79,6 +82,7 @@ class SubAgentInvocation extends BaseToolInvocation<AgentInputs, ToolResult> {
|
||||
_toolName ?? definition.name,
|
||||
_toolDisplayName ?? definition.displayName ?? definition.name,
|
||||
);
|
||||
this.startIndex = config.userHintService.getLatestHintIndex();
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
@@ -88,7 +92,10 @@ class SubAgentInvocation extends BaseToolInvocation<AgentInputs, ToolResult> {
|
||||
override async shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
const invocation = this.buildSubInvocation(this.definition, this.params);
|
||||
const invocation = this.buildSubInvocation(
|
||||
this.definition,
|
||||
this.withUserHints(this.params),
|
||||
);
|
||||
return invocation.shouldConfirmExecute(abortSignal);
|
||||
}
|
||||
|
||||
@@ -107,11 +114,38 @@ class SubAgentInvocation extends BaseToolInvocation<AgentInputs, ToolResult> {
|
||||
);
|
||||
}
|
||||
|
||||
const invocation = this.buildSubInvocation(this.definition, this.params);
|
||||
const invocation = this.buildSubInvocation(
|
||||
this.definition,
|
||||
this.withUserHints(this.params),
|
||||
);
|
||||
|
||||
return invocation.execute(signal, updateOutput);
|
||||
}
|
||||
|
||||
private withUserHints(agentArgs: AgentInputs): AgentInputs {
|
||||
if (this.definition.kind !== 'remote') {
|
||||
return agentArgs;
|
||||
}
|
||||
|
||||
const userHints = this.config.userHintService.getUserHintsAfter(
|
||||
this.startIndex,
|
||||
);
|
||||
const formattedHints = formatUserHintsForModel(userHints);
|
||||
if (!formattedHints) {
|
||||
return agentArgs;
|
||||
}
|
||||
|
||||
const query = agentArgs['query'];
|
||||
if (typeof query !== 'string' || query.trim().length === 0) {
|
||||
return agentArgs;
|
||||
}
|
||||
|
||||
return {
|
||||
...agentArgs,
|
||||
query: `${formattedHints}\n\n${query}`,
|
||||
};
|
||||
}
|
||||
|
||||
private buildSubInvocation(
|
||||
definition: AgentDefinition,
|
||||
agentArgs: AgentInputs,
|
||||
|
||||
Reference in New Issue
Block a user