mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-19 09:41:17 -07:00
fix(core): resolve subagent chat recording gaps and directory inheritance (#24368)
This commit is contained in:
@@ -18,6 +18,8 @@ const {
|
||||
mockSendMessageStream,
|
||||
mockScheduleAgentTools,
|
||||
mockSetSystemInstruction,
|
||||
mockRecordCompletedToolCalls,
|
||||
mockSaveSummary,
|
||||
mockCompress,
|
||||
mockMaybeDiscoverMcpServer,
|
||||
mockStopMcp,
|
||||
@@ -32,6 +34,8 @@ const {
|
||||
}),
|
||||
mockScheduleAgentTools: vi.fn(),
|
||||
mockSetSystemInstruction: vi.fn(),
|
||||
mockRecordCompletedToolCalls: vi.fn(),
|
||||
mockSaveSummary: vi.fn(),
|
||||
mockCompress: vi.fn(),
|
||||
mockMaybeDiscoverMcpServer: vi.fn().mockResolvedValue(undefined),
|
||||
mockStopMcp: vi.fn().mockResolvedValue(undefined),
|
||||
@@ -127,18 +131,21 @@ vi.mock('../context/chatCompressionService.js', () => ({
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('../core/geminiChat.js', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('../core/geminiChat.js')>();
|
||||
return {
|
||||
...actual,
|
||||
GeminiChat: vi.fn().mockImplementation(() => ({
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||
setHistory: mockSetHistory,
|
||||
setSystemInstruction: mockSetSystemInstruction,
|
||||
})),
|
||||
};
|
||||
});
|
||||
vi.mock('../core/geminiChat.js', () => ({
|
||||
StreamEventType: {
|
||||
CHUNK: 'chunk',
|
||||
},
|
||||
GeminiChat: vi.fn().mockImplementation(() => ({
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||
setHistory: mockSetHistory,
|
||||
setSystemInstruction: mockSetSystemInstruction,
|
||||
recordCompletedToolCalls: mockRecordCompletedToolCalls,
|
||||
getChatRecordingService: vi.fn().mockReturnValue({
|
||||
saveSummary: mockSaveSummary,
|
||||
}),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('./agent-scheduler.js', () => ({
|
||||
scheduleAgentTools: mockScheduleAgentTools,
|
||||
@@ -337,6 +344,10 @@ describe('LocalAgentExecutor', () => {
|
||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||
getLastPromptTokenCount: vi.fn(() => 100),
|
||||
setHistory: mockSetHistory,
|
||||
recordCompletedToolCalls: mockRecordCompletedToolCalls,
|
||||
getChatRecordingService: vi.fn().mockReturnValue({
|
||||
saveSummary: mockSaveSummary,
|
||||
}),
|
||||
}) as unknown as GeminiChat,
|
||||
);
|
||||
|
||||
@@ -942,6 +953,20 @@ describe('LocalAgentExecutor', () => {
|
||||
|
||||
// Context checks
|
||||
expect(mockedPromptIdContext.run).toHaveBeenCalledTimes(2); // Two turns
|
||||
|
||||
// Recording checks
|
||||
expect(mockRecordCompletedToolCalls).toHaveBeenCalledTimes(1);
|
||||
expect(mockRecordCompletedToolCalls).toHaveBeenCalledWith(
|
||||
expect.any(String), // model
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
status: 'success',
|
||||
request: expect.objectContaining({ name: LS_TOOL_NAME }),
|
||||
}),
|
||||
]),
|
||||
);
|
||||
expect(mockSaveSummary).toHaveBeenCalledTimes(1);
|
||||
expect(mockSaveSummary).toHaveBeenCalledWith('Found file1.txt');
|
||||
const agentId = executor['agentId'];
|
||||
expect(mockedPromptIdContext.run).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
@@ -2450,6 +2475,10 @@ describe('LocalAgentExecutor', () => {
|
||||
expect(recoveryEvent).toBeInstanceOf(RecoveryAttemptEvent);
|
||||
expect(recoveryEvent.success).toBe(true);
|
||||
expect(recoveryEvent.reason).toBe(AgentTerminateMode.MAX_TURNS);
|
||||
|
||||
// Verify that the summary is saved upon successful recovery
|
||||
expect(mockSaveSummary).toHaveBeenCalledTimes(1);
|
||||
expect(mockSaveSummary).toHaveBeenCalledWith('Recovered!');
|
||||
});
|
||||
|
||||
describe('Model Steering', () => {
|
||||
|
||||
@@ -317,8 +317,10 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
|
||||
await this.tryCompressChat(chat, promptId, combinedSignal);
|
||||
|
||||
const { functionCalls } = await promptIdContext.run(promptId, async () =>
|
||||
this.callModel(chat, currentMessage, combinedSignal, promptId),
|
||||
const { functionCalls, modelToUse } = await promptIdContext.run(
|
||||
promptId,
|
||||
async () =>
|
||||
this.callModel(chat, currentMessage, combinedSignal, promptId),
|
||||
);
|
||||
|
||||
if (combinedSignal.aborted) {
|
||||
@@ -348,6 +350,8 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
|
||||
const { nextMessage, submittedOutput, taskCompleted, aborted } =
|
||||
await this.processFunctionCalls(
|
||||
chat,
|
||||
modelToUse,
|
||||
functionCalls,
|
||||
combinedSignal,
|
||||
promptId,
|
||||
@@ -722,8 +726,17 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
}
|
||||
}
|
||||
|
||||
// === FINAL RETURN LOGIC ===
|
||||
if (terminateReason === AgentTerminateMode.GOAL) {
|
||||
// Save the session summary upon completion
|
||||
if (finalResult && chat) {
|
||||
try {
|
||||
const summary = this.getTruncatedSummary(finalResult);
|
||||
chat.getChatRecordingService()?.saveSummary(summary);
|
||||
} catch (error) {
|
||||
debugLogger.warn('Failed to save subagent session summary.', error);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
result: finalResult || 'Task completed.',
|
||||
terminate_reason: terminateReason,
|
||||
@@ -759,6 +772,18 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
// Recovery Succeeded
|
||||
terminateReason = AgentTerminateMode.GOAL;
|
||||
finalResult = recoveryResult;
|
||||
|
||||
// Save the session summary upon successful recovery
|
||||
try {
|
||||
const summary = this.getTruncatedSummary(finalResult);
|
||||
chat.getChatRecordingService()?.saveSummary(summary);
|
||||
} catch (summaryError) {
|
||||
debugLogger.warn(
|
||||
'Failed to save subagent session summary during recovery.',
|
||||
summaryError,
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
result: finalResult,
|
||||
terminate_reason: terminateReason,
|
||||
@@ -846,7 +871,11 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
message: Content,
|
||||
signal: AbortSignal,
|
||||
promptId: string,
|
||||
): Promise<{ functionCalls: FunctionCall[]; textResponse: string }> {
|
||||
): Promise<{
|
||||
functionCalls: FunctionCall[];
|
||||
textResponse: string;
|
||||
modelToUse: string;
|
||||
}> {
|
||||
const modelConfigAlias = getModelConfigAlias(this.definition);
|
||||
|
||||
// Resolve the model config early to get the concrete model string (which may be `auto`).
|
||||
@@ -931,7 +960,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
}
|
||||
}
|
||||
|
||||
return { functionCalls, textResponse };
|
||||
return { functionCalls, textResponse, modelToUse };
|
||||
}
|
||||
|
||||
/** Initializes a `GeminiChat` instance for the agent run. */
|
||||
@@ -985,6 +1014,8 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
* @returns A new `Content` object for history, any submitted output, and completion status.
|
||||
*/
|
||||
private async processFunctionCalls(
|
||||
chat: GeminiChat,
|
||||
model: string,
|
||||
functionCalls: FunctionCall[],
|
||||
signal: AbortSignal,
|
||||
promptId: string,
|
||||
@@ -1226,6 +1257,9 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
},
|
||||
);
|
||||
|
||||
// Record completed tool calls for persistent chat history
|
||||
chat.recordCompletedToolCalls(model, completedCalls);
|
||||
|
||||
for (const call of completedCalls) {
|
||||
const toolName =
|
||||
toolNameMap.get(call.request.callId) || call.request.name;
|
||||
@@ -1475,4 +1509,15 @@ Important Rules:
|
||||
this.onActivity(event);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncates a string to 200 characters in a Unicode-safe way for session summaries.
|
||||
*/
|
||||
private getTruncatedSummary(text: string): string {
|
||||
const chars = Array.from(text);
|
||||
if (chars.length <= 200) {
|
||||
return text;
|
||||
}
|
||||
return chars.slice(0, 197).join('') + '...';
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import {
|
||||
type ToolCallRecord,
|
||||
type MessageRecord,
|
||||
} from './chatRecordingService.js';
|
||||
import type { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
import { CoreToolCallStatus } from '../scheduler/types.js';
|
||||
import type { Content, Part } from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
@@ -57,6 +58,9 @@ describe('ChatRecordingService', () => {
|
||||
},
|
||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
getDebugMode: vi.fn().mockReturnValue(false),
|
||||
getWorkspaceContext: vi.fn().mockReturnValue({
|
||||
getDirectories: vi.fn().mockReturnValue([]),
|
||||
}),
|
||||
getToolRegistry: vi.fn().mockReturnValue({
|
||||
getTool: vi.fn().mockReturnValue({
|
||||
displayName: 'Test Tool',
|
||||
@@ -66,6 +70,13 @@ describe('ChatRecordingService', () => {
|
||||
}),
|
||||
} as unknown as Config;
|
||||
|
||||
// Ensure mockConfig.config points to itself for AgentLoopContext parity
|
||||
Object.defineProperty(mockConfig, 'config', {
|
||||
get() {
|
||||
return mockConfig;
|
||||
},
|
||||
});
|
||||
|
||||
vi.mocked(getProjectHash).mockReturnValue('test-project-hash');
|
||||
chatRecordingService = new ChatRecordingService(mockConfig);
|
||||
});
|
||||
@@ -132,6 +143,31 @@ describe('ChatRecordingService', () => {
|
||||
expect(files[0]).toBe('test-session-id.json');
|
||||
});
|
||||
|
||||
it('should inherit workspace directories for subagents during initialization', () => {
|
||||
const mockDirectories = ['/project/dir1', '/project/dir2'];
|
||||
vi.mocked(mockConfig.getWorkspaceContext).mockReturnValue({
|
||||
getDirectories: vi.fn().mockReturnValue(mockDirectories),
|
||||
} as unknown as WorkspaceContext);
|
||||
|
||||
// Initialize as a subagent
|
||||
chatRecordingService.initialize(undefined, 'subagent');
|
||||
|
||||
// Recording a message triggers the disk write (deferred until then)
|
||||
chatRecordingService.recordMessage({
|
||||
type: 'user',
|
||||
content: 'ping',
|
||||
model: 'm',
|
||||
});
|
||||
|
||||
const sessionFile = chatRecordingService.getConversationFilePath()!;
|
||||
const conversation = JSON.parse(
|
||||
fs.readFileSync(sessionFile, 'utf8'),
|
||||
) as ConversationRecord;
|
||||
|
||||
expect(conversation.kind).toBe('subagent');
|
||||
expect(conversation.directories).toEqual(mockDirectories);
|
||||
});
|
||||
|
||||
it('should resume from an existing session if provided', () => {
|
||||
const chatsDir = path.join(testTempDir, 'chats');
|
||||
fs.mkdirSync(chatsDir, { recursive: true });
|
||||
|
||||
@@ -218,12 +218,22 @@ export class ChatRecordingService {
|
||||
}
|
||||
this.conversationFile = path.join(chatsDir, filename);
|
||||
|
||||
const directories =
|
||||
this.kind === 'subagent'
|
||||
? [
|
||||
...(this.context.config
|
||||
.getWorkspaceContext()
|
||||
?.getDirectories() ?? []),
|
||||
]
|
||||
: undefined;
|
||||
|
||||
this.writeConversation({
|
||||
sessionId: this.sessionId,
|
||||
projectHash: this.projectHash,
|
||||
startTime: new Date().toISOString(),
|
||||
lastUpdated: new Date().toISOString(),
|
||||
messages: [],
|
||||
directories,
|
||||
kind: this.kind,
|
||||
});
|
||||
}
|
||||
@@ -518,6 +528,13 @@ export class ChatRecordingService {
|
||||
): void {
|
||||
try {
|
||||
if (!this.conversationFile) return;
|
||||
|
||||
// Cache the conversation state even if we don't write to disk yet.
|
||||
// This ensures that subsequent reads (e.g. during recordMessage)
|
||||
// see the initial state (like directories) instead of trying to
|
||||
// read a non-existent file from disk.
|
||||
this.cachedConversation = conversation;
|
||||
|
||||
// Don't write the file yet until there's at least one message.
|
||||
if (conversation.messages.length === 0 && !allowEmpty) return;
|
||||
|
||||
@@ -527,7 +544,6 @@ export class ChatRecordingService {
|
||||
// Compare before updating lastUpdated so the timestamp doesn't
|
||||
// cause a false diff.
|
||||
if (this.cachedLastConvData === newContent) return;
|
||||
this.cachedConversation = conversation;
|
||||
conversation.lastUpdated = new Date().toISOString();
|
||||
const contentToWrite = JSON.stringify(conversation, null, 2);
|
||||
this.cachedLastConvData = contentToWrite;
|
||||
|
||||
Reference in New Issue
Block a user