mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-20 10:10:56 -07:00
feat(core): Fully migrate packages/core to AgentLoopContext. (#22115)
This commit is contained in:
@@ -172,6 +172,9 @@ describe('ChatCompressionService', () => {
|
||||
} as unknown as GenerateContentResponse);
|
||||
|
||||
mockConfig = {
|
||||
get config() {
|
||||
return this;
|
||||
},
|
||||
getCompressionThreshold: vi.fn(),
|
||||
getBaseLlmClient: vi.fn().mockReturnValue({
|
||||
generateContent: mockGenerateContent,
|
||||
|
||||
@@ -43,6 +43,13 @@ describe('ChatRecordingService', () => {
|
||||
);
|
||||
|
||||
mockConfig = {
|
||||
get config() {
|
||||
return this;
|
||||
},
|
||||
toolRegistry: {
|
||||
getTool: vi.fn(),
|
||||
},
|
||||
promptId: 'test-session-id',
|
||||
getSessionId: vi.fn().mockReturnValue('test-session-id'),
|
||||
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
|
||||
storage: {
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type Config } from '../config/config.js';
|
||||
import { type Status } from '../core/coreToolScheduler.js';
|
||||
import { type ThoughtSummary } from '../utils/thoughtUtils.js';
|
||||
import { getProjectHash } from '../utils/paths.js';
|
||||
@@ -20,6 +19,7 @@ import type {
|
||||
} from '@google/genai';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { ToolResultDisplay } from '../tools/tools.js';
|
||||
import type { AgentLoopContext } from '../config/agent-loop-context.js';
|
||||
|
||||
export const SESSION_FILE_PREFIX = 'session-';
|
||||
|
||||
@@ -134,12 +134,12 @@ export class ChatRecordingService {
|
||||
private kind?: 'main' | 'subagent';
|
||||
private queuedThoughts: Array<ThoughtSummary & { timestamp: string }> = [];
|
||||
private queuedTokens: TokensSummary | null = null;
|
||||
private config: Config;
|
||||
private context: AgentLoopContext;
|
||||
|
||||
constructor(config: Config) {
|
||||
this.config = config;
|
||||
this.sessionId = config.getSessionId();
|
||||
this.projectHash = getProjectHash(config.getProjectRoot());
|
||||
constructor(context: AgentLoopContext) {
|
||||
this.context = context;
|
||||
this.sessionId = context.promptId;
|
||||
this.projectHash = getProjectHash(context.config.getProjectRoot());
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -171,9 +171,9 @@ export class ChatRecordingService {
|
||||
this.cachedConversation = null;
|
||||
} else {
|
||||
// Create new session
|
||||
this.sessionId = this.config.getSessionId();
|
||||
this.sessionId = this.context.promptId;
|
||||
const chatsDir = path.join(
|
||||
this.config.storage.getProjectTempDir(),
|
||||
this.context.config.storage.getProjectTempDir(),
|
||||
'chats',
|
||||
);
|
||||
fs.mkdirSync(chatsDir, { recursive: true });
|
||||
@@ -341,7 +341,7 @@ export class ChatRecordingService {
|
||||
if (!this.conversationFile) return;
|
||||
|
||||
// Enrich tool calls with metadata from the ToolRegistry
|
||||
const toolRegistry = this.config.getToolRegistry();
|
||||
const toolRegistry = this.context.toolRegistry;
|
||||
const enrichedToolCalls = toolCalls.map((toolCall) => {
|
||||
const toolInstance = toolRegistry.getTool(toolCall.name);
|
||||
return {
|
||||
@@ -594,7 +594,7 @@ export class ChatRecordingService {
|
||||
*/
|
||||
deleteSession(sessionId: string): void {
|
||||
try {
|
||||
const tempDir = this.config.storage.getProjectTempDir();
|
||||
const tempDir = this.context.config.storage.getProjectTempDir();
|
||||
const chatsDir = path.join(tempDir, 'chats');
|
||||
const sessionPath = path.join(chatsDir, `${sessionId}.json`);
|
||||
if (fs.existsSync(sessionPath)) {
|
||||
|
||||
@@ -36,6 +36,9 @@ describe('LoopDetectionService', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig = {
|
||||
get config() {
|
||||
return this;
|
||||
},
|
||||
getTelemetryEnabled: () => true,
|
||||
isInteractive: () => false,
|
||||
getDisableLoopDetection: () => false,
|
||||
@@ -806,7 +809,13 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
vi.mocked(mockAvailability.snapshot).mockReturnValue({ available: true });
|
||||
|
||||
mockConfig = {
|
||||
get config() {
|
||||
return this;
|
||||
},
|
||||
getGeminiClient: () => mockGeminiClient,
|
||||
get geminiClient() {
|
||||
return mockGeminiClient;
|
||||
},
|
||||
getBaseLlmClient: () => mockBaseLlmClient,
|
||||
getDisableLoopDetection: () => false,
|
||||
getDebugMode: () => false,
|
||||
|
||||
@@ -19,12 +19,12 @@ import {
|
||||
LlmLoopCheckEvent,
|
||||
LlmRole,
|
||||
} from '../telemetry/types.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import {
|
||||
isFunctionCall,
|
||||
isFunctionResponse,
|
||||
} from '../utils/messageInspectors.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { AgentLoopContext } from '../config/agent-loop-context.js';
|
||||
|
||||
const TOOL_CALL_LOOP_THRESHOLD = 5;
|
||||
const CONTENT_LOOP_THRESHOLD = 10;
|
||||
@@ -131,7 +131,7 @@ export interface LoopDetectionResult {
|
||||
* Monitors tool call repetitions and content sentence repetitions.
|
||||
*/
|
||||
export class LoopDetectionService {
|
||||
private readonly config: Config;
|
||||
private readonly context: AgentLoopContext;
|
||||
private promptId = '';
|
||||
private userPrompt = '';
|
||||
|
||||
@@ -157,8 +157,8 @@ export class LoopDetectionService {
|
||||
// Session-level disable flag
|
||||
private disabledForSession = false;
|
||||
|
||||
constructor(config: Config) {
|
||||
this.config = config;
|
||||
constructor(context: AgentLoopContext) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -167,7 +167,7 @@ export class LoopDetectionService {
|
||||
disableForSession(): void {
|
||||
this.disabledForSession = true;
|
||||
logLoopDetectionDisabled(
|
||||
this.config,
|
||||
this.context.config,
|
||||
new LoopDetectionDisabledEvent(this.promptId),
|
||||
);
|
||||
}
|
||||
@@ -184,7 +184,10 @@ export class LoopDetectionService {
|
||||
* @returns A LoopDetectionResult
|
||||
*/
|
||||
addAndCheck(event: ServerGeminiStreamEvent): LoopDetectionResult {
|
||||
if (this.disabledForSession || this.config.getDisableLoopDetection()) {
|
||||
if (
|
||||
this.disabledForSession ||
|
||||
this.context.config.getDisableLoopDetection()
|
||||
) {
|
||||
return { count: 0 };
|
||||
}
|
||||
if (this.loopDetected) {
|
||||
@@ -228,7 +231,7 @@ export class LoopDetectionService {
|
||||
: LoopType.CONTENT_CHANTING_LOOP;
|
||||
|
||||
logLoopDetected(
|
||||
this.config,
|
||||
this.context.config,
|
||||
new LoopDetectedEvent(
|
||||
this.lastLoopType,
|
||||
this.promptId,
|
||||
@@ -256,7 +259,10 @@ export class LoopDetectionService {
|
||||
* @returns A promise that resolves to a LoopDetectionResult.
|
||||
*/
|
||||
async turnStarted(signal: AbortSignal): Promise<LoopDetectionResult> {
|
||||
if (this.disabledForSession || this.config.getDisableLoopDetection()) {
|
||||
if (
|
||||
this.disabledForSession ||
|
||||
this.context.config.getDisableLoopDetection()
|
||||
) {
|
||||
return { count: 0 };
|
||||
}
|
||||
if (this.loopDetected) {
|
||||
@@ -283,7 +289,7 @@ export class LoopDetectionService {
|
||||
this.lastLoopType = LoopType.LLM_DETECTED_LOOP;
|
||||
|
||||
logLoopDetected(
|
||||
this.config,
|
||||
this.context.config,
|
||||
new LoopDetectedEvent(
|
||||
this.lastLoopType,
|
||||
this.promptId,
|
||||
@@ -536,8 +542,7 @@ export class LoopDetectionService {
|
||||
analysis?: string;
|
||||
confirmedByModel?: string;
|
||||
}> {
|
||||
const recentHistory = this.config
|
||||
.getGeminiClient()
|
||||
const recentHistory = this.context.geminiClient
|
||||
.getHistory()
|
||||
.slice(-LLM_LOOP_CHECK_HISTORY_COUNT);
|
||||
|
||||
@@ -590,13 +595,13 @@ export class LoopDetectionService {
|
||||
: '';
|
||||
|
||||
const doubleCheckModelName =
|
||||
this.config.modelConfigService.getResolvedConfig({
|
||||
this.context.config.modelConfigService.getResolvedConfig({
|
||||
model: DOUBLE_CHECK_MODEL_ALIAS,
|
||||
}).model;
|
||||
|
||||
if (flashConfidence < LLM_CONFIDENCE_THRESHOLD) {
|
||||
logLlmLoopCheck(
|
||||
this.config,
|
||||
this.context.config,
|
||||
new LlmLoopCheckEvent(
|
||||
this.promptId,
|
||||
flashConfidence,
|
||||
@@ -608,12 +613,13 @@ export class LoopDetectionService {
|
||||
return { isLoop: false };
|
||||
}
|
||||
|
||||
const availability = this.config.getModelAvailabilityService();
|
||||
const availability = this.context.config.getModelAvailabilityService();
|
||||
|
||||
if (!availability.snapshot(doubleCheckModelName).available) {
|
||||
const flashModelName = this.config.modelConfigService.getResolvedConfig({
|
||||
model: 'loop-detection',
|
||||
}).model;
|
||||
const flashModelName =
|
||||
this.context.config.modelConfigService.getResolvedConfig({
|
||||
model: 'loop-detection',
|
||||
}).model;
|
||||
return {
|
||||
isLoop: true,
|
||||
analysis: flashAnalysis,
|
||||
@@ -642,7 +648,7 @@ export class LoopDetectionService {
|
||||
: undefined;
|
||||
|
||||
logLlmLoopCheck(
|
||||
this.config,
|
||||
this.context.config,
|
||||
new LlmLoopCheckEvent(
|
||||
this.promptId,
|
||||
flashConfidence,
|
||||
@@ -672,7 +678,7 @@ export class LoopDetectionService {
|
||||
signal: AbortSignal,
|
||||
): Promise<Record<string, unknown> | null> {
|
||||
try {
|
||||
const result = await this.config.getBaseLlmClient().generateJson({
|
||||
const result = await this.context.config.getBaseLlmClient().generateJson({
|
||||
modelConfigKey: { model },
|
||||
contents,
|
||||
schema: LOOP_DETECTION_SCHEMA,
|
||||
@@ -692,7 +698,7 @@ export class LoopDetectionService {
|
||||
}
|
||||
return null;
|
||||
} catch (error) {
|
||||
if (this.config.getDebugMode()) {
|
||||
if (this.context.config.getDebugMode()) {
|
||||
debugLogger.warn(
|
||||
`Error querying loop detection model (${model}): ${String(error)}`,
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user