feat(core): experimental in-progress steering hints (1 of 3) (#19008)

This commit is contained in:
joshualitt
2026-02-17 14:59:33 -08:00
committed by GitHub
parent 5e2f5df62c
commit 55c628e967
20 changed files with 1381 additions and 60 deletions
@@ -0,0 +1,41 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, afterEach } from 'vitest';
import { AppRig } from './AppRig.js';
import path from 'node:path';
import { fileURLToPath } from 'node:url';
const __dirname = path.dirname(fileURLToPath(import.meta.url));
describe('AppRig', () => {
let rig: AppRig | undefined;
afterEach(async () => {
await rig?.unmount();
});
it('should render the app and handle a simple message', async () => {
const fakeResponsesPath = path.join(
__dirname,
'fixtures',
'simple.responses',
);
rig = new AppRig({ fakeResponsesPath });
await rig.initialize();
rig.render();
// Wait for initial render
await rig.waitForIdle();
// Type a message
await rig.type('Hello');
await rig.pressEnter();
// Wait for model response
await rig.waitForOutput('Hello! How can I help you today?');
});
});
+568
View File
@@ -0,0 +1,568 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi } from 'vitest';
import { act } from 'react';
import stripAnsi from 'strip-ansi';
import os from 'node:os';
import path from 'node:path';
import fs from 'node:fs';
import { AppContainer } from '../ui/AppContainer.js';
import { renderWithProviders } from './render.js';
import {
makeFakeConfig,
type Config,
type ConfigParameters,
ExtensionLoader,
AuthType,
ApprovalMode,
createPolicyEngineConfig,
PolicyDecision,
ToolConfirmationOutcome,
MessageBusType,
type ToolCallsUpdateMessage,
coreEvents,
ideContextStore,
createContentGenerator,
IdeClient,
debugLogger,
} from '@google/gemini-cli-core';
import {
type MockShellCommand,
MockShellExecutionService,
} from './MockShellExecutionService.js';
import { createMockSettings } from './settings.js';
import { type LoadedSettings } from '../config/settings.js';
import { AuthState } from '../ui/types.js';
// Mock core functions globally for tests using AppRig.
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const original =
await importOriginal<typeof import('@google/gemini-cli-core')>();
const { MockShellExecutionService: MockService } = await import(
'./MockShellExecutionService.js'
);
// Register the real execution logic so MockShellExecutionService can fall back to it
MockService.setOriginalImplementation(original.ShellExecutionService.execute);
return {
...original,
ShellExecutionService: MockService,
};
});
// Mock useAuthCommand to bypass authentication flows in tests
vi.mock('../ui/auth/useAuth.js', () => ({
useAuthCommand: () => ({
authState: AuthState.Authenticated,
setAuthState: vi.fn(),
authError: null,
onAuthError: vi.fn(),
apiKeyDefaultValue: 'test-api-key',
reloadApiKey: vi.fn().mockResolvedValue('test-api-key'),
}),
validateAuthMethodWithSettings: () => null,
}));
// A minimal mock ExtensionManager to satisfy AppContainer's forceful cast
class MockExtensionManager extends ExtensionLoader {
getExtensions = vi.fn().mockReturnValue([]);
setRequestConsent = vi.fn();
setRequestSetting = vi.fn();
}
export interface AppRigOptions {
fakeResponsesPath?: string;
terminalWidth?: number;
terminalHeight?: number;
configOverrides?: Partial<ConfigParameters>;
}
export interface PendingConfirmation {
toolName: string;
toolDisplayName?: string;
correlationId: string;
}
export class AppRig {
private renderResult: ReturnType<typeof renderWithProviders> | undefined;
private config: Config | undefined;
private settings: LoadedSettings | undefined;
private testDir: string;
private sessionId: string;
private pendingConfirmations = new Map<string, PendingConfirmation>();
private breakpointTools = new Set<string | undefined>();
private lastAwaitedConfirmation: PendingConfirmation | undefined;
constructor(private options: AppRigOptions = {}) {
this.testDir = fs.mkdtempSync(path.join(os.tmpdir(), 'gemini-app-rig-'));
this.sessionId = `test-session-${Math.random().toString(36).slice(2, 9)}`;
}
async initialize() {
this.setupEnvironment();
this.settings = this.createRigSettings();
const approvalMode =
this.options.configOverrides?.approvalMode ?? ApprovalMode.DEFAULT;
const policyEngineConfig = await createPolicyEngineConfig(
this.settings.merged,
approvalMode,
);
const configParams: ConfigParameters = {
sessionId: this.sessionId,
targetDir: this.testDir,
cwd: this.testDir,
debugMode: false,
model: 'test-model',
fakeResponses: this.options.fakeResponsesPath,
interactive: true,
approvalMode,
policyEngineConfig,
enableEventDrivenScheduler: true,
extensionLoader: new MockExtensionManager(),
excludeTools: this.options.configOverrides?.excludeTools,
...this.options.configOverrides,
};
this.config = makeFakeConfig(configParams);
if (this.options.fakeResponsesPath) {
this.stubRefreshAuth();
}
this.setupMessageBusListeners();
await act(async () => {
await this.config!.initialize();
// Since we mocked useAuthCommand, we must manually trigger the first
// refreshAuth to ensure contentGenerator is initialized.
await this.config!.refreshAuth(AuthType.USE_GEMINI);
});
}
private setupEnvironment() {
// Stub environment variables to avoid interference from developer's machine
vi.stubEnv('GEMINI_CLI_HOME', this.testDir);
if (this.options.fakeResponsesPath) {
vi.stubEnv('GEMINI_API_KEY', 'test-api-key');
MockShellExecutionService.setPassthrough(false);
} else {
if (!process.env['GEMINI_API_KEY']) {
throw new Error(
'GEMINI_API_KEY must be set in the environment for live model tests.',
);
}
// For live tests, we allow falling through to the real shell service if no mock matches
MockShellExecutionService.setPassthrough(true);
}
vi.stubEnv('GEMINI_DEFAULT_AUTH_TYPE', AuthType.USE_GEMINI);
}
private createRigSettings(): LoadedSettings {
return createMockSettings({
user: {
path: path.join(this.testDir, '.gemini', 'user_settings.json'),
settings: {
security: {
auth: {
selectedType: AuthType.USE_GEMINI,
useExternal: true,
},
folderTrust: {
enabled: true,
},
},
ide: {
enabled: false,
hasSeenNudge: true,
},
},
originalSettings: {},
},
merged: {
security: {
auth: {
selectedType: AuthType.USE_GEMINI,
useExternal: true,
},
folderTrust: {
enabled: true,
},
},
ide: {
enabled: false,
hasSeenNudge: true,
},
},
});
}
private stubRefreshAuth() {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
const gcConfig = this.config as any;
gcConfig.refreshAuth = async (authMethod: AuthType) => {
gcConfig.modelAvailabilityService.reset();
const newContentGeneratorConfig = {
authType: authMethod,
proxy: gcConfig.getProxy(),
apiKey: process.env['GEMINI_API_KEY'] || 'test-api-key',
};
gcConfig.contentGenerator = await createContentGenerator(
newContentGeneratorConfig,
this.config!,
gcConfig.getSessionId(),
);
gcConfig.contentGeneratorConfig = newContentGeneratorConfig;
// Initialize BaseLlmClient now that the ContentGenerator is available
const { BaseLlmClient } = await import('@google/gemini-cli-core');
gcConfig.baseLlmClient = new BaseLlmClient(
gcConfig.contentGenerator,
this.config!,
);
};
}
private setupMessageBusListeners() {
if (!this.config) return;
const messageBus = this.config.getMessageBus();
messageBus.subscribe(
MessageBusType.TOOL_CALLS_UPDATE,
(message: ToolCallsUpdateMessage) => {
for (const call of message.toolCalls) {
if (call.status === 'awaiting_approval' && call.correlationId) {
const details = call.confirmationDetails;
const title = 'title' in details ? details.title : '';
const toolDisplayName =
call.tool?.displayName || title.replace(/^Confirm:\s*/, '');
if (!this.pendingConfirmations.has(call.correlationId)) {
this.pendingConfirmations.set(call.correlationId, {
toolName: call.request.name,
toolDisplayName,
correlationId: call.correlationId,
});
}
} else if (call.status !== 'awaiting_approval') {
for (const [
correlationId,
pending,
] of this.pendingConfirmations.entries()) {
if (pending.toolName === call.request.name) {
this.pendingConfirmations.delete(correlationId);
break;
}
}
}
}
},
);
}
render() {
if (!this.config || !this.settings)
throw new Error('AppRig not initialized');
act(() => {
this.renderResult = renderWithProviders(
<AppContainer
config={this.config!}
version="test-version"
initializationResult={{
authError: null,
themeError: null,
shouldOpenAuthDialog: false,
geminiMdFileCount: 0,
}}
/>,
{
config: this.config!,
settings: this.settings!,
width: this.options.terminalWidth ?? 120,
useAlternateBuffer: false,
uiState: {
terminalHeight: this.options.terminalHeight ?? 40,
},
},
);
});
}
setMockCommands(commands: MockShellCommand[]) {
MockShellExecutionService.setMockCommands(commands);
}
setToolPolicy(
toolName: string | undefined,
decision: PolicyDecision,
priority = 10,
) {
if (!this.config) throw new Error('AppRig not initialized');
this.config.getPolicyEngine().addRule({
toolName,
decision,
priority,
source: 'AppRig Override',
});
}
setBreakpoint(toolName: string | string[] | undefined) {
if (Array.isArray(toolName)) {
for (const name of toolName) {
this.setBreakpoint(name);
}
} else {
this.setToolPolicy(toolName, PolicyDecision.ASK_USER, 100);
this.breakpointTools.add(toolName);
}
}
removeToolPolicy(toolName?: string, source = 'AppRig Override') {
if (!this.config) throw new Error('AppRig not initialized');
this.config
.getPolicyEngine()
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
.removeRulesForTool(toolName as string, source);
this.breakpointTools.delete(toolName);
}
getTestDir(): string {
return this.testDir;
}
getPendingConfirmations() {
return Array.from(this.pendingConfirmations.values());
}
private async waitUntil(
predicate: () => boolean | Promise<boolean>,
options: { timeout?: number; interval?: number; message?: string } = {},
) {
const {
timeout = 30000,
interval = 100,
message = 'Condition timed out',
} = options;
const start = Date.now();
while (true) {
if (await predicate()) return;
if (Date.now() - start > timeout) {
throw new Error(message);
}
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, interval));
});
}
}
async waitForPendingConfirmation(
toolNameOrDisplayName?: string | RegExp,
timeout = 30000,
): Promise<PendingConfirmation> {
const matches = (p: PendingConfirmation) => {
if (!toolNameOrDisplayName) return true;
if (typeof toolNameOrDisplayName === 'string') {
return (
p.toolName === toolNameOrDisplayName ||
p.toolDisplayName === toolNameOrDisplayName
);
}
return (
toolNameOrDisplayName.test(p.toolName) ||
toolNameOrDisplayName.test(p.toolDisplayName || '')
);
};
let matched: PendingConfirmation | undefined;
await this.waitUntil(
() => {
matched = this.getPendingConfirmations().find(matches);
return !!matched;
},
{
timeout,
message: `Timed out waiting for pending confirmation: ${toolNameOrDisplayName || 'any'}. Current pending: ${this.getPendingConfirmations()
.map((p) => p.toolName)
.join(', ')}`,
},
);
this.lastAwaitedConfirmation = matched;
return matched!;
}
async resolveTool(
toolNameOrDisplayName: string | RegExp | PendingConfirmation,
outcome: ToolConfirmationOutcome = ToolConfirmationOutcome.ProceedOnce,
): Promise<void> {
if (!this.config) throw new Error('AppRig not initialized');
const messageBus = this.config.getMessageBus();
let pending: PendingConfirmation;
if (
typeof toolNameOrDisplayName === 'object' &&
'correlationId' in toolNameOrDisplayName
) {
pending = toolNameOrDisplayName;
} else {
pending = await this.waitForPendingConfirmation(toolNameOrDisplayName);
}
await act(async () => {
this.pendingConfirmations.delete(pending.correlationId);
if (this.breakpointTools.has(pending.toolName)) {
this.removeToolPolicy(pending.toolName);
}
// eslint-disable-next-line @typescript-eslint/no-floating-promises
messageBus.publish({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: pending.correlationId,
confirmed: outcome !== ToolConfirmationOutcome.Cancel,
outcome,
});
});
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 100));
});
}
async resolveAwaitedTool(
outcome: ToolConfirmationOutcome = ToolConfirmationOutcome.ProceedOnce,
): Promise<void> {
if (!this.lastAwaitedConfirmation) {
throw new Error('No tool has been awaited yet');
}
await this.resolveTool(this.lastAwaitedConfirmation, outcome);
this.lastAwaitedConfirmation = undefined;
}
async addUserHint(_hint: string) {
if (!this.config) throw new Error('AppRig not initialized');
// TODO(joshualitt): Land hints.
// await act(async () => {
// this.config!.addUserHint(hint);
// });
}
getConfig(): Config {
if (!this.config) throw new Error('AppRig not initialized');
return this.config;
}
async type(text: string) {
if (!this.renderResult) throw new Error('AppRig not initialized');
await act(async () => {
this.renderResult!.stdin.write(text);
});
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 50));
});
}
async pressEnter() {
await this.type('\r');
}
async pressKey(key: string) {
if (!this.renderResult) throw new Error('AppRig not initialized');
await act(async () => {
this.renderResult!.stdin.write(key);
});
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 50));
});
}
get lastFrame() {
if (!this.renderResult) return '';
return stripAnsi(this.renderResult.lastFrame() || '');
}
getStaticOutput() {
if (!this.renderResult) return '';
return stripAnsi(this.renderResult.stdout.lastFrame() || '');
}
async waitForOutput(pattern: string | RegExp, timeout = 30000) {
await this.waitUntil(
() => {
const frame = this.lastFrame;
return typeof pattern === 'string'
? frame.includes(pattern)
: pattern.test(frame);
},
{
timeout,
message: `Timed out waiting for output: ${pattern}\nLast frame:\n${this.lastFrame}`,
},
);
}
async waitForIdle(timeout = 20000) {
await this.waitForOutput('Type your message', timeout);
}
async sendMessage(text: string) {
await this.type(text);
await this.pressEnter();
}
async unmount() {
// Poison the chat recording service to prevent late writes to the test directory
if (this.config) {
const recordingService = this.config
.getGeminiClient()
?.getChatRecordingService();
if (recordingService) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-type-assertion
(recordingService as any).conversationFile = null;
}
}
if (this.renderResult) {
this.renderResult.unmount();
}
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 500));
});
vi.unstubAllEnvs();
coreEvents.removeAllListeners();
coreEvents.drainBacklogs();
MockShellExecutionService.reset();
ideContextStore.clear();
// Forcefully clear IdeClient singleton promise
// eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-type-assertion
(IdeClient as any).instancePromise = null;
vi.clearAllMocks();
this.config = undefined;
this.renderResult = undefined;
if (this.testDir && fs.existsSync(this.testDir)) {
try {
fs.rmSync(this.testDir, { recursive: true, force: true });
} catch (e) {
debugLogger.warn(
`Failed to cleanup test directory ${this.testDir}:`,
e,
);
}
}
}
}
@@ -0,0 +1,140 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi } from 'vitest';
import type {
ShellExecutionHandle,
ShellExecutionResult,
ShellOutputEvent,
ShellExecutionConfig,
} from '@google/gemini-cli-core';
export interface MockShellCommand {
command: string | RegExp;
result: Partial<ShellExecutionResult>;
events?: ShellOutputEvent[];
}
type ShellExecutionServiceExecute = (
commandToExecute: string,
cwd: string,
onOutputEvent: (event: ShellOutputEvent) => void,
abortSignal: AbortSignal,
shouldUseNodePty: boolean,
shellExecutionConfig: ShellExecutionConfig,
) => Promise<ShellExecutionHandle>;
export class MockShellExecutionService {
private static mockCommands: MockShellCommand[] = [];
private static originalExecute: ShellExecutionServiceExecute | undefined;
private static passthroughEnabled = false;
/**
* Registers the original implementation to allow falling back to real shell execution.
*/
static setOriginalImplementation(
implementation: ShellExecutionServiceExecute,
) {
this.originalExecute = implementation;
}
/**
* Enables or disables passthrough to the real implementation when no mock matches.
*/
static setPassthrough(enabled: boolean) {
this.passthroughEnabled = enabled;
}
static setMockCommands(commands: MockShellCommand[]) {
this.mockCommands = commands;
}
static reset() {
this.mockCommands = [];
this.passthroughEnabled = false;
this.writeToPty.mockClear();
this.kill.mockClear();
this.background.mockClear();
this.resizePty.mockClear();
this.scrollPty.mockClear();
}
static async execute(
commandToExecute: string,
cwd: string,
onOutputEvent: (event: ShellOutputEvent) => void,
abortSignal: AbortSignal,
shouldUseNodePty: boolean,
shellExecutionConfig: ShellExecutionConfig,
): Promise<ShellExecutionHandle> {
const mock = this.mockCommands.find((m) =>
typeof m.command === 'string'
? m.command === commandToExecute
: m.command.test(commandToExecute),
);
const pid = Math.floor(Math.random() * 10000);
if (mock) {
if (mock.events) {
for (const event of mock.events) {
onOutputEvent(event);
}
}
const result: ShellExecutionResult = {
rawOutput: Buffer.from(mock.result.output || ''),
output: mock.result.output || '',
exitCode: mock.result.exitCode ?? 0,
signal: mock.result.signal ?? null,
error: mock.result.error ?? null,
aborted: false,
pid,
executionMethod: 'none',
...mock.result,
};
return {
pid,
result: Promise.resolve(result),
};
}
if (this.passthroughEnabled && this.originalExecute) {
return this.originalExecute(
commandToExecute,
cwd,
onOutputEvent,
abortSignal,
shouldUseNodePty,
shellExecutionConfig,
);
}
return {
pid,
result: Promise.resolve({
rawOutput: Buffer.from(''),
output: `Command not found: ${commandToExecute}`,
exitCode: 127,
signal: null,
error: null,
aborted: false,
pid,
executionMethod: 'none',
}),
};
}
static writeToPty = vi.fn();
static isPtyActive = vi.fn(() => false);
static onExit = vi.fn(() => () => {});
static kill = vi.fn();
static background = vi.fn();
static subscribe = vi.fn(() => () => {});
static resizePty = vi.fn();
static scrollPty = vi.fn();
}
@@ -0,0 +1 @@
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"text":"Hello! How can I help you today?"}],"role":"model"},"finishReason":"STOP"}]}]}
+37 -32
View File
@@ -33,6 +33,7 @@ import { makeFakeConfig, type Config } from '@google/gemini-cli-core';
import { FakePersistentState } from './persistentStateFake.js';
import { AppContext, type AppState } from '../ui/contexts/AppContext.js';
import { createMockSettings } from './settings.js';
import { SessionStatsProvider } from '../ui/contexts/SessionContext.js';
import { themeManager, DEFAULT_THEME } from '../ui/themes/theme-manager.js';
import { DefaultLight } from '../ui/themes/default-light.js';
import { pickDefaultThemeName } from '../ui/themes/theme.js';
@@ -324,39 +325,43 @@ export const renderWithProviders = (
<UIStateContext.Provider value={finalUiState}>
<VimModeProvider settings={finalSettings}>
<ShellFocusContext.Provider value={shellFocus}>
<StreamingContext.Provider value={finalUiState.streamingState}>
<UIActionsContext.Provider value={finalUIActions}>
<ToolActionsProvider
config={config}
toolCalls={allToolCalls}
>
<AskUserActionsProvider
request={null}
onSubmit={vi.fn()}
onCancel={vi.fn()}
<SessionStatsProvider>
<StreamingContext.Provider
value={finalUiState.streamingState}
>
<UIActionsContext.Provider value={finalUIActions}>
<ToolActionsProvider
config={config}
toolCalls={allToolCalls}
>
<KeypressProvider>
<MouseProvider
mouseEventsEnabled={mouseEventsEnabled}
>
<TerminalProvider>
<ScrollProvider>
<Box
width={terminalWidth}
flexShrink={0}
flexGrow={0}
flexDirection="column"
>
{component}
</Box>
</ScrollProvider>
</TerminalProvider>
</MouseProvider>
</KeypressProvider>
</AskUserActionsProvider>
</ToolActionsProvider>
</UIActionsContext.Provider>
</StreamingContext.Provider>
<AskUserActionsProvider
request={null}
onSubmit={vi.fn()}
onCancel={vi.fn()}
>
<KeypressProvider>
<MouseProvider
mouseEventsEnabled={mouseEventsEnabled}
>
<TerminalProvider>
<ScrollProvider>
<Box
width={terminalWidth}
flexShrink={0}
flexGrow={0}
flexDirection="column"
>
{component}
</Box>
</ScrollProvider>
</TerminalProvider>
</MouseProvider>
</KeypressProvider>
</AskUserActionsProvider>
</ToolActionsProvider>
</UIActionsContext.Provider>
</StreamingContext.Provider>
</SessionStatsProvider>
</ShellFocusContext.Provider>
</VimModeProvider>
</UIStateContext.Provider>
@@ -159,13 +159,17 @@ vi.mock('./useLogger.js', () => ({
const mockStartNewPrompt = vi.fn();
const mockAddUsage = vi.fn();
vi.mock('../contexts/SessionContext.js', () => ({
useSessionStats: vi.fn(() => ({
startNewPrompt: mockStartNewPrompt,
addUsage: mockAddUsage,
getPromptCount: vi.fn(() => 5),
})),
}));
vi.mock('../contexts/SessionContext.js', async (importOriginal) => {
const actual = (await importOriginal()) as any;
return {
...actual,
useSessionStats: vi.fn(() => ({
startNewPrompt: mockStartNewPrompt,
addUsage: mockAddUsage,
getPromptCount: vi.fn(() => 5),
})),
};
});
vi.mock('./slashCommandProcessor.js', () => ({
handleSlashCommand: vi.fn().mockReturnValue(false),
+11 -5
View File
@@ -41,6 +41,7 @@ import type { SkillDefinition } from '../skills/skillLoader.js';
import type { McpClientManager } from '../tools/mcp-client-manager.js';
import { DEFAULT_MODEL_CONFIGS } from './defaultModelConfigs.js';
import { DEFAULT_GEMINI_MODEL } from './models.js';
import { Storage } from './storage.js';
vi.mock('fs', async (importOriginal) => {
const actual = await importOriginal<typeof import('fs')>();
@@ -279,16 +280,21 @@ describe('Server Config (config.ts)', () => {
await expect(config.initialize()).resolves.toBeUndefined();
});
it('should throw an error if initialized more than once', async () => {
it('should deduplicate multiple calls to initialize', async () => {
const config = new Config({
...baseParams,
checkpointing: false,
});
await expect(config.initialize()).resolves.toBeUndefined();
await expect(config.initialize()).rejects.toThrow(
'Config was already initialized',
);
const storageSpy = vi.spyOn(Storage.prototype, 'initialize');
await Promise.all([
config.initialize(),
config.initialize(),
config.initialize(),
]);
expect(storageSpy).toHaveBeenCalledTimes(1);
});
it('should await MCP initialization in non-interactive mode', async () => {
+13 -6
View File
@@ -621,7 +621,8 @@ export class Config {
private readonly enablePromptCompletion: boolean = false;
private readonly truncateToolOutputThreshold: number;
private compressionTruncationCounter = 0;
private initialized: boolean = false;
private initialized = false;
private initPromise: Promise<void> | undefined;
readonly storage: Storage;
private readonly fileExclusions: FileExclusions;
private readonly eventEmitter?: EventEmitter;
@@ -674,7 +675,6 @@ export class Config {
private remoteAdminSettings: AdminControlsSettings | undefined;
private latestApiRequest: GenerateContentParameters | undefined;
private lastModeSwitchTime: number = Date.now();
private approvedPlanPath: string | undefined;
constructor(params: ConfigParameters) {
@@ -917,14 +917,20 @@ export class Config {
}
/**
* Must only be called once, throws if called again.
* Dedups initialization requests using a shared promise that is only resolved
* once.
*/
async initialize(): Promise<void> {
if (this.initialized) {
throw Error('Config was already initialized');
if (this.initPromise) {
return this.initPromise;
}
this.initialized = true;
this.initPromise = this._initialize();
return this.initPromise;
}
private async _initialize(): Promise<void> {
await this.storage.initialize();
// Add pending directories to workspace context
@@ -1011,6 +1017,7 @@ export class Config {
await this.geminiClient.initialize();
this.syncPlanModeTools();
this.initialized = true;
}
getContentGenerator(): ContentGenerator {
@@ -127,6 +127,19 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
},
},
},
'fast-ack-helper': {
extends: 'base',
modelConfig: {
model: 'gemini-2.5-flash-lite',
generateContentConfig: {
temperature: 0.2,
maxOutputTokens: 120,
thinkingConfig: {
thinkingBudget: 0,
},
},
},
},
'edit-corrector': {
extends: 'base',
modelConfig: {
+2
View File
@@ -28,6 +28,7 @@ export * from './commands/memory.js';
export * from './commands/types.js';
// Export Core Logic
export * from './core/baseLlmClient.js';
export * from './core/client.js';
export * from './core/contentGenerator.js';
export * from './core/loggingContentGenerator.js';
@@ -88,6 +89,7 @@ export * from './utils/formatters.js';
export * from './utils/generateContentResponseUtilities.js';
export * from './utils/filesearch/fileSearch.js';
export * from './utils/errorParsing.js';
export * from './utils/fastAckHelper.js';
export * from './utils/workspaceContext.js';
export * from './utils/environmentContext.js';
export * from './utils/ignorePatterns.js';
@@ -133,6 +133,17 @@
}
}
},
"fast-ack-helper": {
"model": "gemini-2.5-flash-lite",
"generateContentConfig": {
"temperature": 0.2,
"topP": 1,
"maxOutputTokens": 120,
"thinkingConfig": {
"thinkingBudget": 0
}
}
},
"edit-corrector": {
"model": "gemini-2.5-flash-lite",
"generateContentConfig": {
@@ -133,6 +133,17 @@
}
}
},
"fast-ack-helper": {
"model": "gemini-2.5-flash-lite",
"generateContentConfig": {
"temperature": 0.2,
"topP": 1,
"maxOutputTokens": 120,
"thinkingConfig": {
"thinkingBudget": 0
}
}
},
"edit-corrector": {
"model": "gemini-2.5-flash-lite",
"generateContentConfig": {
+1
View File
@@ -15,4 +15,5 @@ export enum LlmRole {
UTILITY_NEXT_SPEAKER = 'utility_next_speaker',
UTILITY_EDIT_CORRECTOR = 'utility_edit_corrector',
UTILITY_AUTOCOMPLETE = 'utility_autocomplete',
UTILITY_FAST_ACK_HELPER = 'utility_fast_ack_helper',
}
@@ -0,0 +1,146 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi } from 'vitest';
import type { BaseLlmClient } from '../core/baseLlmClient.js';
import {
DEFAULT_FAST_ACK_MODEL_CONFIG_KEY,
generateFastAckText,
truncateFastAckInput,
generateSteeringAckMessage,
} from './fastAckHelper.js';
import { LlmRole } from 'src/telemetry/llmRole.js';
describe('truncateFastAckInput', () => {
it('returns input as-is when below limit', () => {
expect(truncateFastAckInput('hello', 10)).toBe('hello');
});
it('truncates and appends suffix when above limit', () => {
const input = 'abcdefghijklmnopqrstuvwxyz';
const result = truncateFastAckInput(input, 20);
// grapheme count is 20
const segmenter = new Intl.Segmenter(undefined, {
granularity: 'grapheme',
});
expect(Array.from(segmenter.segment(result)).length).toBe(20);
expect(result).toContain('...[truncated]');
});
it('is grapheme aware', () => {
const input = '👨‍👩‍👧‍👦'.repeat(10); // 10 family emojis
const result = truncateFastAckInput(input, 5);
// family emoji is 1 grapheme
expect(result).toBe('👨‍👩‍👧‍👦👨‍👩‍👧‍👦👨‍👩‍👧‍👦👨‍👩‍👧‍👦👨‍👩‍👧‍👦');
});
});
describe('generateFastAckText', () => {
const abortSignal = new AbortController().signal;
it('uses the default fast-ack-helper model config and returns response text', async () => {
const llmClient = {
generateContent: vi.fn().mockResolvedValue({
candidates: [
{ content: { parts: [{ text: ' Got it. Skipping #2. ' }] } },
],
}),
} as unknown as BaseLlmClient;
const result = await generateFastAckText(llmClient, {
instruction: 'Write a short acknowledgement sentence.',
input: 'skip #2',
fallbackText: 'Got it.',
abortSignal,
promptId: 'test',
});
expect(result).toBe('Got it. Skipping #2.');
expect(llmClient.generateContent).toHaveBeenCalledWith({
modelConfigKey: DEFAULT_FAST_ACK_MODEL_CONFIG_KEY,
contents: expect.any(Array),
abortSignal,
promptId: 'test',
maxAttempts: 1,
role: LlmRole.UTILITY_FAST_ACK_HELPER,
});
});
it('returns fallback text when response text is empty', async () => {
const llmClient = {
generateContent: vi.fn().mockResolvedValue({}),
} as unknown as BaseLlmClient;
const result = await generateFastAckText(llmClient, {
instruction: 'Return one sentence.',
input: 'cancel task 2',
fallbackText: 'Understood. Cancelling task 2.',
abortSignal,
promptId: 'test',
});
expect(result).toBe('Understood. Cancelling task 2.');
});
it('returns fallback text when generation throws', async () => {
const llmClient = {
generateContent: vi.fn().mockRejectedValue(new Error('boom')),
} as unknown as BaseLlmClient;
const result = await generateFastAckText(llmClient, {
instruction: 'Return one sentence.',
input: 'cancel task 2',
fallbackText: 'Understood.',
abortSignal,
promptId: 'test',
});
expect(result).toBe('Understood.');
});
});
describe('generateSteeringAckMessage', () => {
it('returns a shortened acknowledgement using fast-ack-helper', async () => {
const llmClient = {
generateContent: vi.fn().mockResolvedValue({
candidates: [
{
content: {
parts: [{ text: 'Got it. I will focus on the tests now.' }],
},
},
],
}),
} as unknown as BaseLlmClient;
const result = await generateSteeringAckMessage(
llmClient,
'focus on tests',
);
expect(result).toBe('Got it. I will focus on the tests now.');
});
it('returns a fallback message if the model fails', async () => {
const llmClient = {
generateContent: vi.fn().mockRejectedValue(new Error('timeout')),
} as unknown as BaseLlmClient;
const result = await generateSteeringAckMessage(
llmClient,
'a very long hint that should be truncated in the fallback message if it was longer but it is not',
);
expect(result).toContain('Understood. a very long hint');
});
it('returns a very simple fallback if hint is empty', async () => {
const llmClient = {
generateContent: vi.fn().mockRejectedValue(new Error('error')),
} as unknown as BaseLlmClient;
const result = await generateSteeringAckMessage(llmClient, ' ');
expect(result).toBe('Understood. Adjusting the plan.');
});
});
+199
View File
@@ -0,0 +1,199 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { LlmRole } from '../telemetry/llmRole.js';
import type { BaseLlmClient } from '../core/baseLlmClient.js';
import type { ModelConfigKey } from '../services/modelConfigService.js';
import { debugLogger } from './debugLogger.js';
import { getResponseText } from './partUtils.js';
export const DEFAULT_FAST_ACK_MODEL_CONFIG_KEY: ModelConfigKey = {
model: 'fast-ack-helper',
};
export const DEFAULT_MAX_INPUT_CHARS = 1200;
export const DEFAULT_MAX_OUTPUT_CHARS = 180;
const INPUT_TRUNCATION_SUFFIX = '\n...[truncated]';
/**
* Normalizes whitespace in a string and trims it.
*/
export function normalizeSpace(text: string): string {
return text.replace(/\s+/g, ' ').trim();
}
/**
* Grapheme-aware slice.
*/
function safeSlice(text: string, start: number, end?: number): string {
const segmenter = new Intl.Segmenter(undefined, { granularity: 'grapheme' });
const segments = Array.from(segmenter.segment(text));
return segments
.slice(start, end)
.map((s) => s.segment)
.join('');
}
/**
* Grapheme-aware length.
*/
function safeLength(text: string): number {
const segmenter = new Intl.Segmenter(undefined, { granularity: 'grapheme' });
let count = 0;
for (const _ of segmenter.segment(text)) {
count++;
}
return count;
}
export const USER_STEERING_INSTRUCTION =
'Internal instruction: Re-evaluate the active plan using this user steering update. ' +
'Classify it as ADD_TASK, MODIFY_TASK, CANCEL_TASK, or EXTRA_CONTEXT. ' +
'Apply minimal-diff changes only to affected tasks and keep unaffected tasks active. ' +
'Do not cancel/skip tasks unless the user explicitly cancels them. ' +
'Acknowledge the steering briefly and state the course correction.';
/**
* Wraps user input in XML-like tags to mitigate prompt injection.
*/
function wrapInput(input: string): string {
return `<user_input>\n${input}\n</user_input>`;
}
export function buildUserSteeringHintPrompt(hintText: string): string {
const cleanHint = normalizeSpace(hintText);
return `User steering update:\n${wrapInput(cleanHint)}\n${USER_STEERING_INSTRUCTION}`;
}
export function formatUserHintsForModel(hints: string[]): string | null {
if (hints.length === 0) {
return null;
}
const hintText = hints.map((hint) => `- ${normalizeSpace(hint)}`).join('\n');
return `User hints:\n${wrapInput(hintText)}\n\n${USER_STEERING_INSTRUCTION}`;
}
const STEERING_ACK_INSTRUCTION =
'Write one short, friendly sentence acknowledging a user steering update for an in-progress task. ' +
'Be concrete when possible (e.g., mention skipped/cancelled item numbers). ' +
'Do not apologize, do not mention internal policy, and do not add extra steps.';
const STEERING_ACK_TIMEOUT_MS = 1200;
const STEERING_ACK_MAX_INPUT_CHARS = 320;
const STEERING_ACK_MAX_OUTPUT_CHARS = 90;
function buildSteeringFallbackMessage(hintText: string): string {
const normalized = normalizeSpace(hintText);
if (!normalized) {
return 'Understood. Adjusting the plan.';
}
if (safeLength(normalized) <= 64) {
return `Understood. ${normalized}`;
}
return `Understood. ${safeSlice(normalized, 0, 61)}...`;
}
export async function generateSteeringAckMessage(
llmClient: BaseLlmClient,
hintText: string,
): Promise<string> {
const fallbackText = buildSteeringFallbackMessage(hintText);
const abortController = new AbortController();
const timeout = setTimeout(
() => abortController.abort(),
STEERING_ACK_TIMEOUT_MS,
);
try {
return await generateFastAckText(llmClient, {
instruction: STEERING_ACK_INSTRUCTION,
input: normalizeSpace(hintText),
fallbackText,
abortSignal: abortController.signal,
maxInputChars: STEERING_ACK_MAX_INPUT_CHARS,
maxOutputChars: STEERING_ACK_MAX_OUTPUT_CHARS,
promptId: 'steering-ack',
});
} finally {
clearTimeout(timeout);
}
}
export interface GenerateFastAckTextOptions {
instruction: string;
input: string;
fallbackText: string;
abortSignal: AbortSignal;
promptId: string;
modelConfigKey?: ModelConfigKey;
maxInputChars?: number;
maxOutputChars?: number;
}
export function truncateFastAckInput(
input: string,
maxInputChars: number = DEFAULT_MAX_INPUT_CHARS,
): string {
const suffixLength = safeLength(INPUT_TRUNCATION_SUFFIX);
if (maxInputChars <= suffixLength) {
return safeSlice(input, 0, Math.max(maxInputChars, 0));
}
if (safeLength(input) <= maxInputChars) {
return input;
}
const keepChars = maxInputChars - suffixLength;
return safeSlice(input, 0, keepChars) + INPUT_TRUNCATION_SUFFIX;
}
export async function generateFastAckText(
llmClient: BaseLlmClient,
options: GenerateFastAckTextOptions,
): Promise<string> {
const {
instruction,
input,
fallbackText,
abortSignal,
promptId,
modelConfigKey = DEFAULT_FAST_ACK_MODEL_CONFIG_KEY,
maxInputChars = DEFAULT_MAX_INPUT_CHARS,
maxOutputChars = DEFAULT_MAX_OUTPUT_CHARS,
} = options;
const safeInstruction = instruction.trim();
if (!safeInstruction) {
return fallbackText;
}
const safeInput = truncateFastAckInput(input.trim(), maxInputChars);
const prompt = `${safeInstruction}\n\nUser input:\n${wrapInput(safeInput)}`;
try {
const response = await llmClient.generateContent({
modelConfigKey,
contents: [{ role: 'user', parts: [{ text: prompt }] }],
role: LlmRole.UTILITY_FAST_ACK_HELPER,
abortSignal,
promptId,
maxAttempts: 1, // Fast path, don't retry much
});
const responseText = normalizeSpace(getResponseText(response) || '');
if (!responseText) {
return fallbackText;
}
if (maxOutputChars > 0 && safeLength(responseText) > maxOutputChars) {
return safeSlice(responseText, 0, maxOutputChars).trimEnd();
}
return responseText;
} catch (error) {
debugLogger.debug(
`[FastAckHelper] Generation failed: ${error instanceof Error ? error.message : String(error)}`,
);
return fallbackText;
}
}