mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-16 09:01:17 -07:00
feat(core): experimental in-progress steering hints (2 of 2) (#19307)
This commit is contained in:
@@ -21,7 +21,10 @@ import {
|
||||
CoreToolCallStatus,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { Buffer } from 'node:buffer';
|
||||
import type { HistoryItem, IndividualToolCallDisplay } from '../types.js';
|
||||
import type {
|
||||
HistoryItemToolGroup,
|
||||
IndividualToolCallDisplay,
|
||||
} from '../types.js';
|
||||
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
|
||||
const REF_CONTENT_HEADER = `\n${REFERENCE_CONTENT_START}`;
|
||||
@@ -697,7 +700,7 @@ export async function handleAtCommand({
|
||||
{
|
||||
type: 'tool_group',
|
||||
tools: allDisplays,
|
||||
} as Omit<HistoryItem, 'id'>,
|
||||
} as HistoryItemToolGroup,
|
||||
userMessageTimestamp,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -65,6 +65,11 @@ const MockedGeminiClientClass = vi.hoisted(() =>
|
||||
this.startChat = mockStartChat;
|
||||
this.sendMessageStream = mockSendMessageStream;
|
||||
this.addHistory = vi.fn();
|
||||
this.generateContent = vi.fn().mockResolvedValue({
|
||||
candidates: [
|
||||
{ content: { parts: [{ text: 'Got it. Focusing on tests only.' }] } },
|
||||
],
|
||||
});
|
||||
this.getCurrentSequenceModel = vi.fn().mockReturnValue('test-model');
|
||||
this.getChat = vi.fn().mockReturnValue({
|
||||
recordCompletedToolCalls: vi.fn(),
|
||||
@@ -264,6 +269,13 @@ describe('useGeminiStream', () => {
|
||||
getGlobalMemory: vi.fn(() => ''),
|
||||
getUserMemory: vi.fn(() => ''),
|
||||
getMessageBus: vi.fn(() => mockMessageBus),
|
||||
getBaseLlmClient: vi.fn(() => ({
|
||||
generateContent: vi.fn().mockResolvedValue({
|
||||
candidates: [
|
||||
{ content: { parts: [{ text: 'Got it. Focusing on tests only.' }] } },
|
||||
],
|
||||
}),
|
||||
})),
|
||||
getIdeMode: vi.fn(() => false),
|
||||
getEnableHooks: vi.fn(() => false),
|
||||
} as unknown as Config;
|
||||
@@ -675,6 +687,114 @@ describe('useGeminiStream', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should inject steering hint prompt for continuation', async () => {
|
||||
const toolCallResponseParts: Part[] = [{ text: 'tool final response' }];
|
||||
const completedToolCalls: TrackedToolCall[] = [
|
||||
{
|
||||
request: {
|
||||
callId: 'call1',
|
||||
name: 'tool1',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-ack',
|
||||
},
|
||||
status: 'success',
|
||||
responseSubmittedToGemini: false,
|
||||
response: {
|
||||
callId: 'call1',
|
||||
responseParts: toolCallResponseParts,
|
||||
errorType: undefined,
|
||||
},
|
||||
tool: {
|
||||
displayName: 'MockTool',
|
||||
},
|
||||
invocation: {
|
||||
getDescription: () => `Mock description`,
|
||||
} as unknown as AnyToolInvocation,
|
||||
} as TrackedCompletedToolCall,
|
||||
];
|
||||
|
||||
mockSendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield {
|
||||
type: ServerGeminiEventType.Content,
|
||||
value: 'Applied the requested adjustment.',
|
||||
};
|
||||
})(),
|
||||
);
|
||||
|
||||
let capturedOnComplete:
|
||||
| ((completedTools: TrackedToolCall[]) => Promise<void>)
|
||||
| null = null;
|
||||
mockUseToolScheduler.mockImplementation((onComplete) => {
|
||||
capturedOnComplete = onComplete;
|
||||
return [
|
||||
[],
|
||||
mockScheduleToolCalls,
|
||||
mockMarkToolsAsSubmitted,
|
||||
vi.fn(),
|
||||
mockCancelAllToolCalls,
|
||||
0,
|
||||
];
|
||||
});
|
||||
|
||||
renderHookWithProviders(() =>
|
||||
useGeminiStream(
|
||||
new MockedGeminiClientClass(mockConfig),
|
||||
[],
|
||||
mockAddItem,
|
||||
mockConfig,
|
||||
mockLoadedSettings,
|
||||
mockOnDebugMessage,
|
||||
mockHandleSlashCommand,
|
||||
false,
|
||||
() => 'vscode' as EditorType,
|
||||
() => {},
|
||||
() => Promise.resolve(),
|
||||
false,
|
||||
() => {},
|
||||
() => {},
|
||||
() => {},
|
||||
80,
|
||||
24,
|
||||
undefined,
|
||||
() => 'focus on tests only',
|
||||
),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
if (capturedOnComplete) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
await capturedOnComplete(completedToolCalls);
|
||||
}
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
const sentParts = mockSendMessageStream.mock.calls[0][0] as Part[];
|
||||
const injectedHintPart = sentParts[0] as { text?: string };
|
||||
expect(injectedHintPart.text).toContain('User steering update:');
|
||||
expect(injectedHintPart.text).toContain(
|
||||
'<user_input>\nfocus on tests only\n</user_input>',
|
||||
);
|
||||
expect(injectedHintPart.text).toContain(
|
||||
'Classify it as ADD_TASK, MODIFY_TASK, CANCEL_TASK, or EXTRA_CONTEXT.',
|
||||
);
|
||||
expect(injectedHintPart.text).toContain(
|
||||
'Do not cancel/skip tasks unless the user explicitly cancels them.',
|
||||
);
|
||||
expect(
|
||||
mockAddItem.mock.calls.some(
|
||||
([item]) =>
|
||||
item?.type === 'info' &&
|
||||
typeof item.text === 'string' &&
|
||||
item.text.includes('Got it. Focusing on tests only.'),
|
||||
),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle all tool calls being cancelled', async () => {
|
||||
const cancelledToolCalls: TrackedToolCall[] = [
|
||||
{
|
||||
|
||||
@@ -34,6 +34,8 @@ import {
|
||||
coreEvents,
|
||||
CoreEvent,
|
||||
CoreToolCallStatus,
|
||||
buildUserSteeringHintPrompt,
|
||||
generateSteeringAckMessage,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type {
|
||||
Config,
|
||||
@@ -55,6 +57,7 @@ import type {
|
||||
HistoryItemThinking,
|
||||
HistoryItemWithoutId,
|
||||
HistoryItemToolGroup,
|
||||
HistoryItemInfo,
|
||||
IndividualToolCallDisplay,
|
||||
SlashCommandProcessorResult,
|
||||
HistoryItemModel,
|
||||
@@ -191,6 +194,7 @@ export const useGeminiStream = (
|
||||
terminalWidth: number,
|
||||
terminalHeight: number,
|
||||
isShellFocused?: boolean,
|
||||
consumeUserHint?: () => string | null,
|
||||
) => {
|
||||
const [initError, setInitError] = useState<string | null>(null);
|
||||
const [retryStatus, setRetryStatus] = useState<RetryAttemptPayload | null>(
|
||||
@@ -1604,6 +1608,29 @@ export const useGeminiStream = (
|
||||
const responsesToSend: Part[] = geminiTools.flatMap(
|
||||
(toolCall) => toolCall.response.responseParts,
|
||||
);
|
||||
|
||||
if (consumeUserHint) {
|
||||
const userHint = consumeUserHint();
|
||||
if (userHint && userHint.trim().length > 0) {
|
||||
const hintText = userHint.trim();
|
||||
responsesToSend.unshift({
|
||||
text: buildUserSteeringHintPrompt(hintText),
|
||||
});
|
||||
void generateSteeringAckMessage(
|
||||
config.getBaseLlmClient(),
|
||||
hintText,
|
||||
).then((ackText) => {
|
||||
addItem({
|
||||
type: 'info',
|
||||
icon: '· ',
|
||||
color: theme.text.secondary,
|
||||
marginBottom: 1,
|
||||
text: ackText,
|
||||
} as HistoryItemInfo);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const callIdsToMarkAsSubmitted = geminiTools.map(
|
||||
(toolCall) => toolCall.request.callId,
|
||||
);
|
||||
@@ -1636,6 +1663,8 @@ export const useGeminiStream = (
|
||||
modelSwitchedFromQuotaError,
|
||||
addItem,
|
||||
registerBackgroundShell,
|
||||
consumeUserHint,
|
||||
config,
|
||||
],
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user