feat(core): experimental in-progress steering hints

This is a rebase / refactor of:
https://github.com/google-gemini/gemini-cli/pull/18783
This commit is contained in:
Your Name
2026-02-11 21:14:29 +00:00
parent ef02cec2cd
commit 5ed64c7130
45 changed files with 2090 additions and 136 deletions
@@ -0,0 +1,81 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, afterEach } from 'vitest';
import { AppRig } from '../test-utils/AppRig.js';
import path from 'node:path';
import { fileURLToPath } from 'node:url';
import { PolicyDecision } from '@google/gemini-cli-core';
const __dirname = path.dirname(fileURLToPath(import.meta.url));
describe('Model Steering Integration', () => {
let rig: AppRig | undefined;
afterEach(async () => {
await rig?.unmount();
});
it('should steer the model using a hint during a tool turn', async () => {
const fakeResponsesPath = path.join(
__dirname,
'../test-utils/fixtures/steering.responses',
);
rig = new AppRig({ fakeResponsesPath });
await rig.initialize();
rig.render();
await rig.waitForIdle();
rig.setToolPolicy('list_directory', PolicyDecision.ASK_USER);
rig.setToolPolicy('read_file', PolicyDecision.ASK_USER);
rig.setMockCommands([
{
command: /list_directory/,
result: {
output: 'file1.txt\nfile2.js\nfile3.md',
exitCode: 0,
},
},
{
command: /read_file file1.txt/,
result: {
output: 'This is file1.txt content.',
exitCode: 0,
},
},
]);
// Start a long task
await rig.type('Start long task');
await rig.pressEnter();
// Wait for the model to call 'list_directory' (Confirming state)
await rig.waitForOutput('ReadFolder');
// Injected a hint while the model is in a tool turn
await rig.addUserHint('focus on .txt');
// Resolve list_directory (Proceed)
await rig.resolveTool('ReadFolder');
// Wait for the model to process the hint and output the next action
// Based on steering.responses, it should first acknowledge the hint
await rig.waitForOutput('ACK: I will focus on .txt files now.');
// Then it should proceed with the next action
await rig.waitForOutput(
/Since you want me to focus on .txt files,[\s\S]*I will read file1.txt/,
);
await rig.waitForOutput('ReadFile');
// Resolve read_file (Proceed)
await rig.resolveTool('ReadFile');
// Wait for final completion
await rig.waitForOutput('Task complete.');
});
});
@@ -0,0 +1,80 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, afterEach, expect } from 'vitest';
import { AppRig } from './AppRig.js';
import path from 'node:path';
import { fileURLToPath } from 'node:url';
import { debugLogger } from '@google/gemini-cli-core';
const __dirname = path.dirname(fileURLToPath(import.meta.url));
describe('AppRig', () => {
let rig: AppRig | undefined;
afterEach(async () => {
await rig?.unmount();
});
it('should handle deterministic tool turns with breakpoints', async () => {
const fakeResponsesPath = path.join(
__dirname,
'fixtures',
'steering.responses',
);
rig = new AppRig({ fakeResponsesPath });
await rig.initialize();
rig.render();
await rig.waitForIdle();
// Set breakpoints on the canonical tool names
rig.setBreakpoint('list_directory');
rig.setBreakpoint('read_file');
// Start a task
debugLogger.log('[Test] Sending message: Start long task');
await rig.sendMessage('Start long task');
// Wait for the first breakpoint (list_directory)
const pending1 = await rig.waitForPendingConfirmation('list_directory');
expect(pending1.toolName).toBe('list_directory');
// Injected a hint
await rig.addUserHint('focus on .txt');
// Resolve and wait for the NEXT breakpoint (read_file)
// resolveTool will automatically remove the breakpoint policy for list_directory
await rig.resolveTool('list_directory');
const pending2 = await rig.waitForPendingConfirmation('read_file');
expect(pending2.toolName).toBe('read_file');
// Resolve and finish. Also removes read_file breakpoint.
await rig.resolveTool('read_file');
await rig.waitForOutput('Task complete.', 100000);
});
it('should render the app and handle a simple message', async () => {
const fakeResponsesPath = path.join(
__dirname,
'fixtures',
'simple.responses',
);
rig = new AppRig({ fakeResponsesPath });
await rig.initialize();
rig.render();
// Wait for initial render
await rig.waitForIdle();
// Type a message
await rig.type('Hello');
await rig.pressEnter();
// Wait for model response
await rig.waitForOutput('Hello! How can I help you today?');
});
});
+569
View File
@@ -0,0 +1,569 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi } from 'vitest';
import { act } from 'react';
import stripAnsi from 'strip-ansi';
import os from 'node:os';
import path from 'node:path';
import fs from 'node:fs';
import { AppContainer } from '../ui/AppContainer.js';
import { renderWithProviders } from './render.js';
import {
makeFakeConfig,
type Config,
type ConfigParameters,
ExtensionLoader,
AuthType,
ApprovalMode,
createPolicyEngineConfig,
PolicyDecision,
ToolConfirmationOutcome,
MessageBusType,
type ToolCallsUpdateMessage,
coreEvents,
ideContextStore,
createContentGenerator,
startupProfiler,
IdeClient,
debugLogger,
} from '@google/gemini-cli-core';
import {
type MockShellCommand,
MockShellExecutionService,
} from './MockShellExecutionService.js';
import { createMockSettings } from './settings.js';
import { type LoadedSettings } from '../config/settings.js';
import { AuthState } from '../ui/types.js';
// Mock core functions globally for tests using AppRig.
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const original =
await importOriginal<typeof import('@google/gemini-cli-core')>();
const { MockShellExecutionService: MockService } = await import(
'./MockShellExecutionService.js'
);
// Register the real execution logic so MockShellExecutionService can fall back to it
MockService.setOriginalImplementation(original.ShellExecutionService.execute);
return {
...original,
ShellExecutionService: MockService,
};
});
// Mock useAuthCommand to bypass authentication flows in tests
vi.mock('../ui/auth/useAuth.js', () => ({
useAuthCommand: () => ({
authState: AuthState.Authenticated,
setAuthState: vi.fn(),
authError: null,
onAuthError: vi.fn(),
apiKeyDefaultValue: 'test-api-key',
reloadApiKey: vi.fn().mockResolvedValue('test-api-key'),
}),
validateAuthMethodWithSettings: () => null,
}));
// A minimal mock ExtensionManager to satisfy AppContainer's forceful cast
class MockExtensionManager extends ExtensionLoader {
getExtensions = vi.fn().mockReturnValue([]);
setRequestConsent = vi.fn();
setRequestSetting = vi.fn();
}
export interface AppRigOptions {
fakeResponsesPath?: string;
terminalWidth?: number;
terminalHeight?: number;
configOverrides?: Partial<ConfigParameters>;
}
export interface PendingConfirmation {
toolName: string;
toolDisplayName?: string;
correlationId: string;
}
export class AppRig {
private renderResult: ReturnType<typeof renderWithProviders> | undefined;
private config: Config | undefined;
private settings: LoadedSettings | undefined;
private testDir: string;
private sessionId: string;
private pendingConfirmations = new Map<string, PendingConfirmation>();
private breakpointTools = new Set<string | undefined>();
private lastAwaitedConfirmation: PendingConfirmation | undefined;
constructor(private options: AppRigOptions = {}) {
this.testDir = fs.mkdtempSync(path.join(os.tmpdir(), 'gemini-app-rig-'));
this.sessionId = `test-session-${Math.random().toString(36).slice(2, 9)}`;
}
async initialize() {
this.setupEnvironment();
this.settings = this.createRigSettings();
const approvalMode =
this.options.configOverrides?.approvalMode ?? ApprovalMode.DEFAULT;
const policyEngineConfig = await createPolicyEngineConfig(
this.settings.merged,
approvalMode,
);
const configParams: ConfigParameters = {
sessionId: this.sessionId,
targetDir: this.testDir,
cwd: this.testDir,
debugMode: false,
model: 'test-model',
fakeResponses: this.options.fakeResponsesPath,
interactive: true,
approvalMode,
policyEngineConfig,
enableEventDrivenScheduler: true,
extensionLoader: new MockExtensionManager(),
excludeTools: this.options.configOverrides?.excludeTools,
...this.options.configOverrides,
};
this.config = makeFakeConfig(configParams);
if (this.options.fakeResponsesPath) {
this.stubRefreshAuth();
}
this.setupMessageBusListeners();
await act(async () => {
await this.config!.initialize();
// Since we mocked useAuthCommand, we must manually trigger the first
// refreshAuth to ensure contentGenerator is initialized.
await this.config!.refreshAuth(AuthType.USE_GEMINI);
});
}
private setupEnvironment() {
// Stub environment variables to avoid interference from developer's machine
vi.stubEnv('GEMINI_CLI_HOME', this.testDir);
if (this.options.fakeResponsesPath) {
vi.stubEnv('GEMINI_API_KEY', 'test-api-key');
MockShellExecutionService.setPassthrough(false);
} else {
if (!process.env['GEMINI_API_KEY']) {
throw new Error(
'GEMINI_API_KEY must be set in the environment for live model tests.',
);
}
// For live tests, we allow falling through to the real shell service if no mock matches
MockShellExecutionService.setPassthrough(true);
}
vi.stubEnv('GEMINI_DEFAULT_AUTH_TYPE', AuthType.USE_GEMINI);
}
private createRigSettings(): LoadedSettings {
return createMockSettings({
user: {
path: path.join(this.testDir, '.gemini', 'user_settings.json'),
settings: {
security: {
auth: {
selectedType: AuthType.USE_GEMINI,
useExternal: true,
},
folderTrust: {
enabled: true,
},
},
ide: {
enabled: false,
hasSeenNudge: true,
},
},
originalSettings: {},
},
merged: {
security: {
auth: {
selectedType: AuthType.USE_GEMINI,
useExternal: true,
},
folderTrust: {
enabled: true,
},
},
ide: {
enabled: false,
hasSeenNudge: true,
},
},
});
}
private stubRefreshAuth() {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
const gcConfig = this.config as any;
gcConfig.refreshAuth = async (authMethod: AuthType) => {
gcConfig.modelAvailabilityService.reset();
const newContentGeneratorConfig = {
authType: authMethod,
proxy: gcConfig.getProxy(),
apiKey: process.env['GEMINI_API_KEY'] || 'test-api-key',
};
gcConfig.contentGenerator = await createContentGenerator(
newContentGeneratorConfig,
this.config!,
gcConfig.getSessionId(),
);
gcConfig.contentGeneratorConfig = newContentGeneratorConfig;
// Initialize BaseLlmClient now that the ContentGenerator is available
const { BaseLlmClient } = await import('@google/gemini-cli-core');
gcConfig.baseLlmClient = new BaseLlmClient(
gcConfig.contentGenerator,
this.config!,
);
};
}
private setupMessageBusListeners() {
if (!this.config) return;
const messageBus = this.config.getMessageBus();
messageBus.subscribe(
MessageBusType.TOOL_CALLS_UPDATE,
(message: ToolCallsUpdateMessage) => {
for (const call of message.toolCalls) {
if (call.status === 'awaiting_approval' && call.correlationId) {
const details = call.confirmationDetails;
const title = 'title' in details ? details.title : '';
const toolDisplayName =
call.tool?.displayName || title.replace(/^Confirm:\s*/, '');
if (!this.pendingConfirmations.has(call.correlationId)) {
this.pendingConfirmations.set(call.correlationId, {
toolName: call.request.name,
toolDisplayName,
correlationId: call.correlationId,
});
}
} else if (call.status !== 'awaiting_approval') {
for (const [
correlationId,
pending,
] of this.pendingConfirmations.entries()) {
if (pending.toolName === call.request.name) {
this.pendingConfirmations.delete(correlationId);
break;
}
}
}
}
},
);
}
render() {
if (!this.config || !this.settings)
throw new Error('AppRig not initialized');
act(() => {
this.renderResult = renderWithProviders(
<AppContainer
config={this.config!}
version="test-version"
initializationResult={{
authError: null,
themeError: null,
shouldOpenAuthDialog: false,
geminiMdFileCount: 0,
}}
/>,
{
config: this.config!,
settings: this.settings!,
width: this.options.terminalWidth ?? 120,
useAlternateBuffer: false,
uiState: {
terminalHeight: this.options.terminalHeight ?? 40,
},
},
);
});
}
setMockCommands(commands: MockShellCommand[]) {
MockShellExecutionService.setMockCommands(commands);
}
setToolPolicy(
toolName: string | undefined,
decision: PolicyDecision,
priority = 10,
) {
if (!this.config) throw new Error('AppRig not initialized');
this.config.getPolicyEngine().addRule({
toolName,
decision,
priority,
source: 'AppRig Override',
});
}
setBreakpoint(toolName: string | string[] | undefined) {
if (Array.isArray(toolName)) {
for (const name of toolName) {
this.setBreakpoint(name);
}
} else {
this.setToolPolicy(toolName, PolicyDecision.ASK_USER, 100);
this.breakpointTools.add(toolName);
}
}
removeToolPolicy(toolName?: string, source = 'AppRig Override') {
if (!this.config) throw new Error('AppRig not initialized');
this.config
.getPolicyEngine()
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
.removeRulesForTool(toolName as string, source);
this.breakpointTools.delete(toolName);
}
getTestDir(): string {
return this.testDir;
}
getPendingConfirmations() {
return Array.from(this.pendingConfirmations.values());
}
private async waitUntil(
predicate: () => boolean | Promise<boolean>,
options: { timeout?: number; interval?: number; message?: string } = {},
) {
const {
timeout = 30000,
interval = 100,
message = 'Condition timed out',
} = options;
const start = Date.now();
while (true) {
if (await predicate()) return;
if (Date.now() - start > timeout) {
throw new Error(message);
}
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, interval));
});
}
}
async waitForPendingConfirmation(
toolNameOrDisplayName?: string | RegExp,
timeout = 30000,
): Promise<PendingConfirmation> {
const matches = (p: PendingConfirmation) => {
if (!toolNameOrDisplayName) return true;
if (typeof toolNameOrDisplayName === 'string') {
return (
p.toolName === toolNameOrDisplayName ||
p.toolDisplayName === toolNameOrDisplayName
);
}
return (
toolNameOrDisplayName.test(p.toolName) ||
toolNameOrDisplayName.test(p.toolDisplayName || '')
);
};
let matched: PendingConfirmation | undefined;
await this.waitUntil(
() => {
matched = this.getPendingConfirmations().find(matches);
return !!matched;
},
{
timeout,
message: `Timed out waiting for pending confirmation: ${toolNameOrDisplayName || 'any'}. Current pending: ${this.getPendingConfirmations()
.map((p) => p.toolName)
.join(', ')}`,
},
);
this.lastAwaitedConfirmation = matched;
return matched!;
}
async resolveTool(
toolNameOrDisplayName: string | RegExp | PendingConfirmation,
outcome: ToolConfirmationOutcome = ToolConfirmationOutcome.ProceedOnce,
): Promise<void> {
if (!this.config) throw new Error('AppRig not initialized');
const messageBus = this.config.getMessageBus();
let pending: PendingConfirmation;
if (
typeof toolNameOrDisplayName === 'object' &&
'correlationId' in toolNameOrDisplayName
) {
pending = toolNameOrDisplayName;
} else {
pending = await this.waitForPendingConfirmation(toolNameOrDisplayName);
}
await act(async () => {
this.pendingConfirmations.delete(pending.correlationId);
if (this.breakpointTools.has(pending.toolName)) {
this.removeToolPolicy(pending.toolName);
}
// eslint-disable-next-line @typescript-eslint/no-floating-promises
messageBus.publish({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: pending.correlationId,
confirmed: outcome !== ToolConfirmationOutcome.Cancel,
outcome,
});
});
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 100));
});
}
async resolveAwaitedTool(
outcome: ToolConfirmationOutcome = ToolConfirmationOutcome.ProceedOnce,
): Promise<void> {
if (!this.lastAwaitedConfirmation) {
throw new Error('No tool has been awaited yet');
}
await this.resolveTool(this.lastAwaitedConfirmation, outcome);
this.lastAwaitedConfirmation = undefined;
}
async addUserHint(hint: string) {
if (!this.config) throw new Error('AppRig not initialized');
await act(async () => {
this.config!.addUserHint(hint);
});
}
getConfig(): Config {
if (!this.config) throw new Error('AppRig not initialized');
return this.config;
}
async type(text: string) {
if (!this.renderResult) throw new Error('AppRig not initialized');
await act(async () => {
this.renderResult!.stdin.write(text);
});
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 50));
});
}
async pressEnter() {
await this.type('\r');
}
async pressKey(key: string) {
if (!this.renderResult) throw new Error('AppRig not initialized');
await act(async () => {
this.renderResult!.stdin.write(key);
});
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 50));
});
}
get lastFrame() {
if (!this.renderResult) return '';
return stripAnsi(this.renderResult.lastFrame() || '');
}
getStaticOutput() {
if (!this.renderResult) return '';
return stripAnsi(this.renderResult.stdout.lastFrame() || '');
}
async waitForOutput(pattern: string | RegExp, timeout = 30000) {
await this.waitUntil(
() => {
const frame = this.lastFrame;
return typeof pattern === 'string'
? frame.includes(pattern)
: pattern.test(frame);
},
{
timeout,
message: `Timed out waiting for output: ${pattern}\nLast frame:\n${this.lastFrame}`,
},
);
}
async waitForIdle(timeout = 20000) {
await this.waitForOutput('Type your message', timeout);
}
async sendMessage(text: string) {
await this.type(text);
await this.pressEnter();
}
async unmount() {
// Poison the chat recording service to prevent late writes to the test directory
if (this.config) {
const recordingService = this.config
.getGeminiClient()
?.getChatRecordingService();
if (recordingService) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-type-assertion
(recordingService as any).conversationFile = null;
}
}
if (this.renderResult) {
this.renderResult.unmount();
}
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 500));
});
vi.unstubAllEnvs();
coreEvents.removeAllListeners();
coreEvents.drainBacklogs();
MockShellExecutionService.reset();
ideContextStore.clear();
// Forcefully clear IdeClient singleton promise
// eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-type-assertion
(IdeClient as any).instancePromise = null;
startupProfiler.clear();
vi.clearAllMocks();
this.config = undefined;
this.renderResult = undefined;
if (this.testDir && fs.existsSync(this.testDir)) {
try {
fs.rmSync(this.testDir, { recursive: true, force: true });
} catch (e) {
debugLogger.warn(
`Failed to cleanup test directory ${this.testDir}:`,
e,
);
}
}
}
}
@@ -0,0 +1,140 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi } from 'vitest';
import type {
ShellExecutionHandle,
ShellExecutionResult,
ShellOutputEvent,
ShellExecutionConfig,
} from '@google/gemini-cli-core';
export interface MockShellCommand {
command: string | RegExp;
result: Partial<ShellExecutionResult>;
events?: ShellOutputEvent[];
}
type ShellExecutionServiceExecute = (
commandToExecute: string,
cwd: string,
onOutputEvent: (event: ShellOutputEvent) => void,
abortSignal: AbortSignal,
shouldUseNodePty: boolean,
shellExecutionConfig: ShellExecutionConfig,
) => Promise<ShellExecutionHandle>;
export class MockShellExecutionService {
private static mockCommands: MockShellCommand[] = [];
private static originalExecute: ShellExecutionServiceExecute | undefined;
private static passthroughEnabled = false;
/**
* Registers the original implementation to allow falling back to real shell execution.
*/
static setOriginalImplementation(
implementation: ShellExecutionServiceExecute,
) {
this.originalExecute = implementation;
}
/**
* Enables or disables passthrough to the real implementation when no mock matches.
*/
static setPassthrough(enabled: boolean) {
this.passthroughEnabled = enabled;
}
static setMockCommands(commands: MockShellCommand[]) {
this.mockCommands = commands;
}
static reset() {
this.mockCommands = [];
this.passthroughEnabled = false;
this.writeToPty.mockClear();
this.kill.mockClear();
this.background.mockClear();
this.resizePty.mockClear();
this.scrollPty.mockClear();
}
static async execute(
commandToExecute: string,
cwd: string,
onOutputEvent: (event: ShellOutputEvent) => void,
abortSignal: AbortSignal,
shouldUseNodePty: boolean,
shellExecutionConfig: ShellExecutionConfig,
): Promise<ShellExecutionHandle> {
const mock = this.mockCommands.find((m) =>
typeof m.command === 'string'
? m.command === commandToExecute
: m.command.test(commandToExecute),
);
const pid = Math.floor(Math.random() * 10000);
if (mock) {
if (mock.events) {
for (const event of mock.events) {
onOutputEvent(event);
}
}
const result: ShellExecutionResult = {
rawOutput: Buffer.from(mock.result.output || ''),
output: mock.result.output || '',
exitCode: mock.result.exitCode ?? 0,
signal: mock.result.signal ?? null,
error: mock.result.error ?? null,
aborted: false,
pid,
executionMethod: 'none',
...mock.result,
};
return {
pid,
result: Promise.resolve(result),
};
}
if (this.passthroughEnabled && this.originalExecute) {
return this.originalExecute(
commandToExecute,
cwd,
onOutputEvent,
abortSignal,
shouldUseNodePty,
shellExecutionConfig,
);
}
return {
pid,
result: Promise.resolve({
rawOutput: Buffer.from(''),
output: `Command not found: ${commandToExecute}`,
exitCode: 127,
signal: null,
error: null,
aborted: false,
pid,
executionMethod: 'none',
}),
};
}
static writeToPty = vi.fn();
static isPtyActive = vi.fn(() => false);
static onExit = vi.fn(() => () => {});
static kill = vi.fn();
static background = vi.fn();
static subscribe = vi.fn(() => () => {});
static resizePty = vi.fn();
static scrollPty = vi.fn();
}
@@ -0,0 +1 @@
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"text":"Hello! How can I help you today?"}],"role":"model"},"finishReason":"STOP"}]}]}
@@ -0,0 +1,4 @@
{"method":"generateContentStream","response":[{"candidates":[{"content":{"role":"model","parts":[{"text":"Starting a long task. First, I'll list the files."},{"functionCall":{"name":"list_directory","args":{"dir_path":"."}}}]},"finishReason":"STOP"}]}]}
{"method":"generateContent","response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ACK: I will focus on .txt files now."}]},"finishReason":"STOP"}]}}
{"method":"generateContentStream","response":[{"candidates":[{"content":{"role":"model","parts":[{"text":"I see the files. Since you want me to focus on .txt files, I will read file1.txt."},{"functionCall":{"name":"read_file","args":{"file_path":"file1.txt"}}}]},"finishReason":"STOP"}]}]}
{"method":"generateContentStream","response":[{"candidates":[{"content":{"role":"model","parts":[{"text":"I have read file1.txt. Task complete."}]},"finishReason":"STOP"}]}]}
+43 -32
View File
@@ -33,6 +33,7 @@ import { makeFakeConfig, type Config } from '@google/gemini-cli-core';
import { FakePersistentState } from './persistentStateFake.js';
import { AppContext, type AppState } from '../ui/contexts/AppContext.js';
import { createMockSettings } from './settings.js';
import { SessionStatsProvider } from '../ui/contexts/SessionContext.js';
export const persistentStateMock = new FakePersistentState();
@@ -160,6 +161,8 @@ const baseMockUiState = {
proQuotaRequest: null,
validationRequest: null,
},
hintMode: false,
hintBuffer: '',
};
export const mockAppState: AppState = {
@@ -209,6 +212,10 @@ const mockUIActions: UIActions = {
setActiveBackgroundShellPid: vi.fn(),
setIsBackgroundShellListOpen: vi.fn(),
setAuthContext: vi.fn(),
onHintInput: vi.fn(),
onHintBackspace: vi.fn(),
onHintClear: vi.fn(),
onHintSubmit: vi.fn(),
handleRestart: vi.fn(),
handleNewAgentsSelect: vi.fn(),
};
@@ -306,39 +313,43 @@ export const renderWithProviders = (
<UIStateContext.Provider value={finalUiState}>
<VimModeProvider settings={finalSettings}>
<ShellFocusContext.Provider value={shellFocus}>
<StreamingContext.Provider value={finalUiState.streamingState}>
<UIActionsContext.Provider value={finalUIActions}>
<ToolActionsProvider
config={config}
toolCalls={allToolCalls}
>
<AskUserActionsProvider
request={null}
onSubmit={vi.fn()}
onCancel={vi.fn()}
<SessionStatsProvider>
<StreamingContext.Provider
value={finalUiState.streamingState}
>
<UIActionsContext.Provider value={finalUIActions}>
<ToolActionsProvider
config={config}
toolCalls={allToolCalls}
>
<KeypressProvider>
<MouseProvider
mouseEventsEnabled={mouseEventsEnabled}
>
<TerminalProvider>
<ScrollProvider>
<Box
width={terminalWidth}
flexShrink={0}
flexGrow={0}
flexDirection="column"
>
{component}
</Box>
</ScrollProvider>
</TerminalProvider>
</MouseProvider>
</KeypressProvider>
</AskUserActionsProvider>
</ToolActionsProvider>
</UIActionsContext.Provider>
</StreamingContext.Provider>
<AskUserActionsProvider
request={null}
onSubmit={vi.fn()}
onCancel={vi.fn()}
>
<KeypressProvider>
<MouseProvider
mouseEventsEnabled={mouseEventsEnabled}
>
<TerminalProvider>
<ScrollProvider>
<Box
width={terminalWidth}
flexShrink={0}
flexGrow={0}
flexDirection="column"
>
{component}
</Box>
</ScrollProvider>
</TerminalProvider>
</MouseProvider>
</KeypressProvider>
</AskUserActionsProvider>
</ToolActionsProvider>
</UIActionsContext.Provider>
</StreamingContext.Provider>
</SessionStatsProvider>
</ShellFocusContext.Provider>
</VimModeProvider>
</UIStateContext.Provider>
+95
View File
@@ -94,6 +94,10 @@ import { basename } from 'node:path';
import { computeTerminalTitle } from '../utils/windowTitle.js';
import { useTextBuffer } from './components/shared/text-buffer.js';
import { useLogger } from './hooks/useLogger.js';
import {
buildUserSteeringHintPrompt,
generateSteeringAckMessage,
} from '@google/gemini-cli-core';
import { useGeminiStream } from './hooks/useGeminiStream.js';
import { type BackgroundShell } from './hooks/shellCommandProcessor.js';
import { useVim } from './hooks/vim.js';
@@ -603,6 +607,7 @@ export const AppContainer = (props: AppContainerProps) => {
apiKeyDefaultValue,
reloadApiKey,
} = useAuthCommand(settings, config, initializationResult.authError);
const [authContext, setAuthContext] = useState<{ requiresRestart?: boolean }>(
{},
);
@@ -963,6 +968,19 @@ Logging in with Google... Restarting Gemini CLI to continue.
}
}, [pendingRestorePrompt, inputHistory, historyManager.history]);
const lastProcessedHintIndexRef = useRef<number>(-1);
const consumePendingHints = useCallback(() => {
const userHints = config.getUserHintsAfter(
lastProcessedHintIndexRef.current,
);
if (userHints.length === 0) {
return null;
}
lastProcessedHintIndexRef.current = config.getLatestHintIndex();
return userHints.join('\n');
}, [config]);
const {
streamingState,
submitQuery,
@@ -1001,6 +1019,7 @@ Logging in with Google... Restarting Gemini CLI to continue.
terminalWidth,
terminalHeight,
embeddedShellFocused,
consumePendingHints,
);
toggleBackgroundShellRef.current = toggleBackgroundShell;
@@ -1103,10 +1122,38 @@ Logging in with Google... Restarting Gemini CLI to continue.
],
);
const handleHintSubmit = useCallback(
(hint: string) => {
const trimmed = hint.trim();
if (!trimmed) {
return;
}
config.addUserHint(trimmed);
// Render hints with a distinct style.
historyManager.addItem({
type: 'hint',
text: trimmed,
} as Omit<HistoryItem, 'id'>);
},
[config, historyManager],
);
const handleFinalSubmit = useCallback(
async (submittedValue: string) => {
const isSlash = isSlashCommand(submittedValue.trim());
const isIdle = streamingState === StreamingState.Idle;
const isAgentRunning =
streamingState === StreamingState.Responding ||
isToolExecuting([
...pendingSlashCommandHistoryItems,
...pendingGeminiHistoryItems,
]);
if (isAgentRunning && !isSlash) {
handleHintSubmit(submittedValue);
addInput(submittedValue);
return;
}
if (isSlash || (isIdle && isMcpReady)) {
if (!isSlash) {
@@ -1148,7 +1195,10 @@ Logging in with Google... Restarting Gemini CLI to continue.
isMcpReady,
streamingState,
messageQueue.length,
pendingSlashCommandHistoryItems,
pendingGeminiHistoryItems,
config,
handleHintSubmit,
],
);
@@ -1814,6 +1864,45 @@ Logging in with Google... Restarting Gemini CLI to continue.
[pendingSlashCommandHistoryItems, pendingGeminiHistoryItems],
);
useEffect(() => {
if (
!isConfigInitialized ||
streamingState !== StreamingState.Idle ||
!isMcpReady ||
isToolAwaitingConfirmation(pendingHistoryItems)
) {
return;
}
const pendingHint = consumePendingHints();
if (!pendingHint) {
return;
}
const geminiClient = config.getGeminiClient();
void generateSteeringAckMessage(geminiClient, pendingHint).then(
(ackText) => {
historyManager.addItem({
type: 'info',
icon: '· ',
color: 'gray',
marginBottom: 1,
text: ackText,
} as Omit<HistoryItem, 'id'>);
},
);
void submitQuery([{ text: buildUserSteeringHintPrompt(pendingHint) }]);
}, [
config,
historyManager,
isConfigInitialized,
isMcpReady,
streamingState,
submitQuery,
consumePendingHints,
pendingHistoryItems,
]);
const allToolCalls = useMemo(
() =>
pendingHistoryItems
@@ -1975,6 +2064,8 @@ Logging in with Google... Restarting Gemini CLI to continue.
isBackgroundShellListOpen,
adminSettingsChanged,
newAgents,
hintMode: false,
hintBuffer: '',
}),
[
isThemeDialogOpen,
@@ -2137,6 +2228,10 @@ Logging in with Google... Restarting Gemini CLI to continue.
setActiveBackgroundShellPid,
setIsBackgroundShellListOpen,
setAuthContext,
onHintInput: () => {},
onHintBackspace: () => {},
onHintClear: () => {},
onHintSubmit: () => {},
handleRestart: async () => {
if (process.send) {
const remoteSettings = config.getRemoteAdminSettings();
@@ -50,6 +50,7 @@ export const DialogManager = ({
const uiState = useUIState();
const uiActions = useUIActions();
const {
constrainHeight,
terminalHeight,
+2 -1
View File
@@ -71,7 +71,8 @@ export const Footer: React.FC = () => {
const justifyContent = hideCWD && hideModelInfo ? 'center' : 'space-between';
const displayVimMode = vimEnabled ? vimMode : undefined;
const showDebugProfiler = debugMode || isDevelopment;
const showDebugProfiler =
debugMode || (isDevelopment && settings.merged.general.devtools);
return (
<Box
@@ -96,6 +96,7 @@ describe('<Header />', () => {
},
background: {
primary: '',
hintMode: '',
diff: { added: '', removed: '' },
},
border: {
@@ -44,6 +44,18 @@ describe('<HistoryItemDisplay />', () => {
expect(lastFrame()).toContain('Hello');
});
it('renders HintMessage for "hint" type', () => {
const item: HistoryItem = {
...baseItem,
type: 'hint',
text: 'Try using ripgrep first',
};
const { lastFrame } = renderWithProviders(
<HistoryItemDisplay {...baseItem} item={item} />,
);
expect(lastFrame()).toContain('Try using ripgrep first');
});
it('renders UserMessage for "user" type with slash command', () => {
const item: HistoryItem = {
...baseItem,
@@ -35,6 +35,7 @@ import { ChatList } from './views/ChatList.js';
import { HooksList } from './views/HooksList.js';
import { ModelMessage } from './messages/ModelMessage.js';
import { ThinkingMessage } from './messages/ThinkingMessage.js';
import { HintMessage } from './messages/HintMessage.js';
import { getInlineThinkingMode } from '../utils/inlineThinkingMode.js';
import { useSettings } from '../contexts/SettingsContext.js';
@@ -71,6 +72,9 @@ export const HistoryItemDisplay: React.FC<HistoryItemDisplayProps> = ({
{itemForDisplay.type === 'thinking' && inlineThinkingMode !== 'off' && (
<ThinkingMessage thought={itemForDisplay.thought} />
)}
{itemForDisplay.type === 'hint' && (
<HintMessage text={itemForDisplay.text} />
)}
{itemForDisplay.type === 'user' && (
<UserMessage text={itemForDisplay.text} width={terminalWidth} />
)}
@@ -102,6 +106,7 @@ export const HistoryItemDisplay: React.FC<HistoryItemDisplayProps> = ({
text={itemForDisplay.text}
icon={itemForDisplay.icon}
color={itemForDisplay.color}
marginBottom={itemForDisplay.marginBottom}
/>
)}
{itemForDisplay.type === 'warning' && (
@@ -238,7 +238,7 @@ export const InputPrompt: React.FC<InputPromptProps> = ({
]);
const [expandedSuggestionIndex, setExpandedSuggestionIndex] =
useState<number>(-1);
const shellHistory = useShellHistory(config.getProjectRoot());
const shellHistory = useShellHistory(config.getProjectRoot(), config.storage);
const shellHistoryData = shellHistory.history;
const completion = useCommandCompletion({
@@ -0,0 +1,53 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type React from 'react';
import { Text, Box } from 'ink';
import { theme } from '../../semantic-colors.js';
import { SCREEN_READER_USER_PREFIX } from '../../textConstants.js';
import { HalfLinePaddedBox } from '../shared/HalfLinePaddedBox.js';
import { useConfig } from '../../contexts/ConfigContext.js';
interface HintMessageProps {
text: string;
}
export const HintMessage: React.FC<HintMessageProps> = ({ text }) => {
const prefix = '💡 ';
const prefixWidth = prefix.length;
const config = useConfig();
const useBackgroundColor = config.getUseBackgroundColor();
return (
<HalfLinePaddedBox
backgroundBaseColor={theme.text.accent}
backgroundOpacity={0.1}
useBackgroundColor={useBackgroundColor}
>
<Box
flexDirection="row"
paddingY={0}
marginY={useBackgroundColor ? 0 : 1}
paddingX={useBackgroundColor ? 1 : 0}
alignSelf="flex-start"
>
<Box width={prefixWidth} flexShrink={0}>
<Text
color={theme.text.accent}
aria-label={SCREEN_READER_USER_PREFIX}
>
{prefix}
</Text>
</Box>
<Box flexGrow={1}>
<Text wrap="wrap" italic color={theme.text.accent}>
{`Steering Hint: ${text}`}
</Text>
</Box>
</Box>
</HalfLinePaddedBox>
);
};
@@ -13,19 +13,21 @@ interface InfoMessageProps {
text: string;
icon?: string;
color?: string;
marginBottom?: number;
}
export const InfoMessage: React.FC<InfoMessageProps> = ({
text,
icon,
color,
marginBottom,
}) => {
color ??= theme.status.warning;
const prefix = icon ?? ' ';
const prefixWidth = prefix.length;
return (
<Box flexDirection="row" marginTop={1}>
<Box flexDirection="row" marginTop={1} marginBottom={marginBottom ?? 0}>
<Box width={prefixWidth}>
<Text color={color}>{prefix}</Text>
</Box>
@@ -73,6 +73,10 @@ export interface UIActions {
setActiveBackgroundShellPid: (pid: number) => void;
setIsBackgroundShellListOpen: (isOpen: boolean) => void;
setAuthContext: (context: { requiresRestart?: boolean }) => void;
onHintInput: (char: string) => void;
onHintBackspace: () => void;
onHintClear: () => void;
onHintSubmit: (hint: string) => void;
handleRestart: () => void;
handleNewAgentsSelect: (choice: NewAgentsChoice) => Promise<void>;
}
@@ -173,6 +173,8 @@ export interface UIState {
isBackgroundShellListOpen: boolean;
adminSettingsChanged: boolean;
newAgents: AgentDefinition[] | null;
hintMode: boolean;
hintBuffer: string;
transientMessage: {
text: string;
type: TransientMessageType;
@@ -56,6 +56,11 @@ const MockedGeminiClientClass = vi.hoisted(() =>
this.startChat = mockStartChat;
this.sendMessageStream = mockSendMessageStream;
this.addHistory = vi.fn();
this.generateContent = vi.fn().mockResolvedValue({
candidates: [
{ content: { parts: [{ text: 'Got it. Focusing on tests only.' }] } },
],
});
this.getCurrentSequenceModel = vi.fn().mockReturnValue('test-model');
this.getChat = vi.fn().mockReturnValue({
recordCompletedToolCalls: vi.fn(),
@@ -152,13 +157,17 @@ vi.mock('./useLogger.js', () => ({
const mockStartNewPrompt = vi.fn();
const mockAddUsage = vi.fn();
vi.mock('../contexts/SessionContext.js', () => ({
useSessionStats: vi.fn(() => ({
startNewPrompt: mockStartNewPrompt,
addUsage: mockAddUsage,
getPromptCount: vi.fn(() => 5),
})),
}));
vi.mock('../contexts/SessionContext.js', async (importOriginal) => {
const actual = (await importOriginal()) as any;
return {
...actual,
useSessionStats: vi.fn(() => ({
startNewPrompt: mockStartNewPrompt,
addUsage: mockAddUsage,
getPromptCount: vi.fn(() => 5),
})),
};
});
vi.mock('./slashCommandProcessor.js', () => ({
handleSlashCommand: vi.fn().mockReturnValue(false),
@@ -661,6 +670,113 @@ describe('useGeminiStream', () => {
);
});
it('should inject steering hint prompt for continuation', async () => {
const toolCallResponseParts: Part[] = [{ text: 'tool final response' }];
const completedToolCalls: TrackedToolCall[] = [
{
request: {
callId: 'call1',
name: 'tool1',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-ack',
},
status: 'success',
responseSubmittedToGemini: false,
response: {
callId: 'call1',
responseParts: toolCallResponseParts,
errorType: undefined,
},
tool: {
displayName: 'MockTool',
},
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
} as TrackedCompletedToolCall,
];
mockSendMessageStream.mockReturnValue(
(async function* () {
yield {
type: ServerGeminiEventType.Content,
value: 'Applied the requested adjustment.',
};
})(),
);
let capturedOnComplete:
| ((completedTools: TrackedToolCall[]) => Promise<void>)
| null = null;
mockUseToolScheduler.mockImplementation((onComplete) => {
capturedOnComplete = onComplete;
return [
[],
mockScheduleToolCalls,
mockMarkToolsAsSubmitted,
vi.fn(),
mockCancelAllToolCalls,
0,
];
});
renderHookWithProviders(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
80,
24,
undefined,
() => 'focus on tests only',
),
);
await act(async () => {
if (capturedOnComplete) {
await new Promise((resolve) => setTimeout(resolve, 0));
await capturedOnComplete(completedToolCalls);
}
});
await waitFor(() => {
expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
});
const sentParts = mockSendMessageStream.mock.calls[0][0] as Part[];
const injectedHintPart = sentParts[0] as { text?: string };
expect(injectedHintPart.text).toContain(
'User steering update: "focus on tests only"',
);
expect(injectedHintPart.text).toContain(
'Classify it as ADD_TASK, MODIFY_TASK, CANCEL_TASK, or EXTRA_CONTEXT.',
);
expect(injectedHintPart.text).toContain(
'Do not cancel/skip tasks unless the user explicitly cancels them.',
);
expect(
mockAddItem.mock.calls.some(
([item]) =>
item?.type === 'info' &&
typeof item.text === 'string' &&
item.text.includes('Got it. Focusing on tests only.'),
),
).toBe(true);
});
it('should handle all tool calls being cancelled', async () => {
const cancelledToolCalls: TrackedToolCall[] = [
{
@@ -32,6 +32,8 @@ import {
ValidationRequiredError,
coreEvents,
CoreEvent,
buildUserSteeringHintPrompt,
generateSteeringAckMessage,
} from '@google/gemini-cli-core';
import type {
Config,
@@ -81,6 +83,7 @@ import path from 'node:path';
import { useSessionStats } from '../contexts/SessionContext.js';
import { useKeypress } from './useKeypress.js';
import type { LoadedSettings } from '../../config/settings.js';
import { theme } from '../semantic-colors.js';
type ToolResponseWithParts = ToolCallResponseInfo & {
llmContent?: PartListUnion;
@@ -185,6 +188,7 @@ export const useGeminiStream = (
terminalWidth: number,
terminalHeight: number,
isShellFocused?: boolean,
consumeUserHint?: () => string | null,
) => {
const [initError, setInitError] = useState<string | null>(null);
const [retryStatus, setRetryStatus] = useState<RetryAttemptPayload | null>(
@@ -1561,6 +1565,28 @@ export const useGeminiStream = (
const responsesToSend: Part[] = geminiTools.flatMap(
(toolCall) => toolCall.response.responseParts,
);
if (consumeUserHint) {
const userHint = consumeUserHint();
if (userHint && userHint.trim().length > 0) {
const hintText = userHint.trim();
responsesToSend.unshift({
text: buildUserSteeringHintPrompt(hintText),
});
void generateSteeringAckMessage(geminiClient, hintText).then(
(ackText) => {
addItem({
type: 'info',
icon: '· ',
color: theme.text.secondary,
marginBottom: 1,
text: ackText,
} as Omit<HistoryItem, 'id'>);
},
);
}
}
const callIdsToMarkAsSubmitted = geminiTools.map(
(toolCall) => toolCall.request.callId,
);
@@ -1593,6 +1619,7 @@ export const useGeminiStream = (
modelSwitchedFromQuotaError,
addItem,
registerBackgroundShell,
consumeUserHint,
],
);
+16 -4
View File
@@ -79,14 +79,26 @@ export function useShellHistory(
const [historyFilePath, setHistoryFilePath] = useState<string | null>(null);
useEffect(() => {
let isMounted = true;
async function loadHistory() {
const filePath = await getHistoryFilePath(projectRoot, storage);
setHistoryFilePath(filePath);
const loadedHistory = await readHistoryFile(filePath);
setHistory(loadedHistory.reverse()); // Newest first
try {
const filePath = await getHistoryFilePath(projectRoot, storage);
if (!isMounted) return;
setHistoryFilePath(filePath);
const loadedHistory = await readHistoryFile(filePath);
if (!isMounted) return;
setHistory(loadedHistory.reverse()); // Newest first
} catch (error) {
if (isMounted) {
debugLogger.error('Error loading shell history:', error);
}
}
}
// eslint-disable-next-line @typescript-eslint/no-floating-promises
loadHistory();
return () => {
isMounted = false;
};
}, [projectRoot, storage]);
const addCommandToHistory = useCallback(
+1
View File
@@ -36,6 +36,7 @@ const noColorSemanticColors: SemanticColors = {
},
background: {
primary: '',
hintMode: '',
diff: {
added: '',
removed: '',
@@ -16,6 +16,7 @@ export interface SemanticColors {
};
background: {
primary: string;
hintMode: string;
diff: {
added: string;
removed: string;
@@ -48,6 +49,7 @@ export const lightSemanticColors: SemanticColors = {
},
background: {
primary: lightTheme.Background,
hintMode: '#E8E0F0',
diff: {
added: lightTheme.DiffAdded,
removed: lightTheme.DiffRemoved,
@@ -80,6 +82,7 @@ export const darkSemanticColors: SemanticColors = {
},
background: {
primary: darkTheme.Background,
hintMode: '#352A45',
diff: {
added: darkTheme.DiffAdded,
removed: darkTheme.DiffRemoved,
+2
View File
@@ -131,6 +131,7 @@ export class Theme {
},
background: {
primary: this.colors.Background,
hintMode: this.type === 'light' ? '#E8E0F0' : '#352A45',
diff: {
added: this.colors.DiffAdded,
removed: this.colors.DiffRemoved,
@@ -400,6 +401,7 @@ export function createCustomTheme(customTheme: CustomTheme): Theme {
},
background: {
primary: customTheme.background?.primary ?? colors.Background,
hintMode: 'magenta',
diff: {
added: customTheme.background?.diff?.added ?? colors.DiffAdded,
removed: customTheme.background?.diff?.removed ?? colors.DiffRemoved,
+8
View File
@@ -123,6 +123,7 @@ export type HistoryItemInfo = HistoryItemBase & {
text: string;
icon?: string;
color?: string;
marginBottom?: number;
};
export type HistoryItemError = HistoryItemBase & {
@@ -225,6 +226,11 @@ export type HistoryItemThinking = HistoryItemBase & {
thought: ThoughtSummary;
};
export type HistoryItemHint = HistoryItemBase & {
type: 'hint';
text: string;
};
export type HistoryItemChatList = HistoryItemBase & {
type: 'chat_list';
chats: ChatDetail[];
@@ -349,6 +355,7 @@ export type HistoryItemWithoutId =
| HistoryItemMcpStatus
| HistoryItemChatList
| HistoryItemThinking
| HistoryItemHint
| HistoryItemHooksList;
export type HistoryItem = HistoryItemWithoutId & { id: number };
@@ -374,6 +381,7 @@ export enum MessageType {
MCP_STATUS = 'mcp_status',
CHAT_LIST = 'chat_list',
HOOKS_LIST = 'hooks_list',
HINT = 'hint',
}
// Simplified message structure for internal feedback