Core data structure updates for Rewind functionality (#15714)

This commit is contained in:
Adib234
2026-01-07 12:10:22 -05:00
committed by GitHub
parent db99beda36
commit 57012ae5b3
12 changed files with 145 additions and 5 deletions
@@ -182,6 +182,7 @@ describe('<ToolMessage />', () => {
fileName: 'file.txt', fileName: 'file.txt',
originalContent: 'old', originalContent: 'old',
newContent: 'new', newContent: 'new',
filePath: 'file.txt',
}; };
const { lastFrame } = renderWithContext( const { lastFrame } = renderWithContext(
<ToolMessage {...baseProps} resultDisplay={diffResult} />, <ToolMessage {...baseProps} resultDisplay={diffResult} />,
@@ -64,6 +64,7 @@ exports[`useReactToolScheduler > should handle tool requiring confirmation - can
"resultDisplay": { "resultDisplay": {
"fileDiff": "Mock tool requires confirmation", "fileDiff": "Mock tool requires confirmation",
"fileName": "mockToolRequiresConfirmation.ts", "fileName": "mockToolRequiresConfirmation.ts",
"filePath": undefined,
"newContent": undefined, "newContent": undefined,
"originalContent": undefined, "originalContent": undefined,
}, },
+33
View File
@@ -38,6 +38,7 @@ import { ideContextStore } from '../ide/ideContext.js';
import type { ModelRouterService } from '../routing/modelRouterService.js'; import type { ModelRouterService } from '../routing/modelRouterService.js';
import { uiTelemetryService } from '../telemetry/uiTelemetry.js'; import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
import { ChatCompressionService } from '../services/chatCompressionService.js'; import { ChatCompressionService } from '../services/chatCompressionService.js';
import type { ChatRecordingService } from '../services/chatRecordingService.js';
import { createAvailabilityServiceMock } from '../availability/testUtils.js'; import { createAvailabilityServiceMock } from '../availability/testUtils.js';
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js'; import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
import type { import type {
@@ -397,6 +398,10 @@ describe('Gemini Client (client.ts)', () => {
getHistory: vi.fn((_curated?: boolean) => chatHistory), getHistory: vi.fn((_curated?: boolean) => chatHistory),
setHistory: vi.fn(), setHistory: vi.fn(),
getLastPromptTokenCount: vi.fn().mockReturnValue(originalTokenCount), getLastPromptTokenCount: vi.fn().mockReturnValue(originalTokenCount),
getChatRecordingService: vi.fn().mockReturnValue({
getConversation: vi.fn().mockReturnValue(null),
getConversationFilePath: vi.fn().mockReturnValue(null),
}),
}; };
client['chat'] = mockOriginalChat as GeminiChat; client['chat'] = mockOriginalChat as GeminiChat;
@@ -617,6 +622,34 @@ describe('Gemini Client (client.ts)', () => {
newTokenCount: 50, newTokenCount: 50,
}); });
}); });
it('should resume the session file when compression succeeds', async () => {
const { client, mockOriginalChat } = setup({
compressionStatus: CompressionStatus.COMPRESSED,
});
const mockConversation = { some: 'conversation' };
const mockFilePath = '/tmp/session.json';
// Override the mock to return values
const mockRecordingService = {
getConversation: vi.fn().mockReturnValue(mockConversation),
getConversationFilePath: vi.fn().mockReturnValue(mockFilePath),
};
vi.mocked(mockOriginalChat.getChatRecordingService!).mockReturnValue(
mockRecordingService as unknown as ChatRecordingService,
);
await client.tryCompressChat('prompt-id', false);
expect(client['startChat']).toHaveBeenCalledWith(
expect.anything(), // newHistory
{
conversation: mockConversation,
filePath: mockFilePath,
},
);
});
}); });
describe('sendMessageStream', () => { describe('sendMessageStream', () => {
+13 -1
View File
@@ -972,7 +972,19 @@ export class GeminiClient {
this.hasFailedCompressionAttempt || !force; this.hasFailedCompressionAttempt || !force;
} else if (info.compressionStatus === CompressionStatus.COMPRESSED) { } else if (info.compressionStatus === CompressionStatus.COMPRESSED) {
if (newHistory) { if (newHistory) {
this.chat = await this.startChat(newHistory); // capture current session data before resetting
const currentRecordingService =
this.getChat().getChatRecordingService();
const conversation = currentRecordingService.getConversation();
const filePath = currentRecordingService.getConversationFilePath();
let resumedData: ResumedSessionData | undefined;
if (conversation && filePath) {
resumedData = { conversation, filePath };
}
this.chat = await this.startChat(newHistory, resumedData);
this.updateTelemetryTokenCount(); this.updateTelemetryTokenCount();
this.forceFullIdeContext = true; this.forceFullIdeContext = true;
} }
@@ -279,6 +279,7 @@ export class CoreToolScheduler {
originalContent: originalContent:
waitingCall.confirmationDetails.originalContent, waitingCall.confirmationDetails.originalContent,
newContent: waitingCall.confirmationDetails.newContent, newContent: waitingCall.confirmationDetails.newContent,
filePath: waitingCall.confirmationDetails.filePath,
}; };
} }
} }
+4 -1
View File
@@ -838,7 +838,10 @@ export class GeminiChat {
const toolCallRecords = toolCalls.map((call) => { const toolCallRecords = toolCalls.map((call) => {
const resultDisplayRaw = call.response?.resultDisplay; const resultDisplayRaw = call.response?.resultDisplay;
const resultDisplay = const resultDisplay =
typeof resultDisplayRaw === 'string' ? resultDisplayRaw : undefined; typeof resultDisplayRaw === 'string' ||
(typeof resultDisplayRaw === 'object' && resultDisplayRaw !== null)
? resultDisplayRaw
: undefined;
return { return {
id: call.request.callId, id: call.request.callId,
@@ -401,4 +401,57 @@ describe('ChatRecordingService', () => {
); );
}); });
}); });
describe('rewindTo', () => {
it('should rewind the conversation to a specific message ID', () => {
chatRecordingService.initialize();
const initialConversation = {
sessionId: 'test-session-id',
projectHash: 'test-project-hash',
messages: [
{ id: '1', type: 'user', content: 'msg1' },
{ id: '2', type: 'gemini', content: 'msg2' },
{ id: '3', type: 'user', content: 'msg3' },
],
};
vi.spyOn(fs, 'readFileSync').mockReturnValue(
JSON.stringify(initialConversation),
);
const writeFileSyncSpy = vi
.spyOn(fs, 'writeFileSync')
.mockImplementation(() => undefined);
const result = chatRecordingService.rewindTo('2');
if (!result) throw new Error('Result should not be null');
expect(result.messages).toHaveLength(1);
expect(result.messages[0].id).toBe('1');
expect(writeFileSyncSpy).toHaveBeenCalled();
const savedConversation = JSON.parse(
writeFileSyncSpy.mock.calls[0][1] as string,
) as ConversationRecord;
expect(savedConversation.messages).toHaveLength(1);
});
it('should return the original conversation if the message ID is not found', () => {
chatRecordingService.initialize();
const initialConversation = {
sessionId: 'test-session-id',
projectHash: 'test-project-hash',
messages: [{ id: '1', type: 'user', content: 'msg1' }],
};
vi.spyOn(fs, 'readFileSync').mockReturnValue(
JSON.stringify(initialConversation),
);
const writeFileSyncSpy = vi
.spyOn(fs, 'writeFileSync')
.mockImplementation(() => undefined);
const result = chatRecordingService.rewindTo('non-existent');
if (!result) throw new Error('Result should not be null');
expect(result.messages).toHaveLength(1);
expect(writeFileSyncSpy).not.toHaveBeenCalled();
});
});
}); });
@@ -16,6 +16,7 @@ import type {
GenerateContentResponseUsageMetadata, GenerateContentResponseUsageMetadata,
} from '@google/genai'; } from '@google/genai';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
import type { ToolResultDisplay } from '../tools/tools.js';
export const SESSION_FILE_PREFIX = 'session-'; export const SESSION_FILE_PREFIX = 'session-';
@@ -53,7 +54,7 @@ export interface ToolCallRecord {
// UI-specific fields for display purposes // UI-specific fields for display purposes
displayName?: string; displayName?: string;
description?: string; description?: string;
resultDisplay?: string; resultDisplay?: ToolResultDisplay;
renderOutputAsMarkdown?: boolean; renderOutputAsMarkdown?: boolean;
} }
@@ -407,11 +408,14 @@ export class ChatRecordingService {
/** /**
* Saves the conversation record; overwrites the file. * Saves the conversation record; overwrites the file.
*/ */
private writeConversation(conversation: ConversationRecord): void { private writeConversation(
conversation: ConversationRecord,
{ allowEmpty = false }: { allowEmpty?: boolean } = {},
): void {
try { try {
if (!this.conversationFile) return; if (!this.conversationFile) return;
// Don't write the file yet until there's at least one message. // Don't write the file yet until there's at least one message.
if (conversation.messages.length === 0) return; if (conversation.messages.length === 0 && !allowEmpty) return;
// Only write the file if this change would change the file. // Only write the file if this change would change the file.
if (this.cachedLastConvData !== JSON.stringify(conversation, null, 2)) { if (this.cachedLastConvData !== JSON.stringify(conversation, null, 2)) {
@@ -492,4 +496,29 @@ export class ChatRecordingService {
throw error; throw error;
} }
} }
/**
* Rewinds the conversation to the state just before the specified message ID.
* All messages from (and including) the specified ID onwards are removed.
*/
rewindTo(messageId: string): ConversationRecord | null {
if (!this.conversationFile) {
return null;
}
const conversation = this.readConversation();
const messageIndex = conversation.messages.findIndex(
(m) => m.id === messageId,
);
if (messageIndex === -1) {
debugLogger.error(
'Message to rewind to not found in conversation history',
);
return conversation;
}
conversation.messages = conversation.messages.slice(0, messageIndex);
this.writeConversation(conversation, { allowEmpty: true });
return conversation;
}
} }
@@ -1053,6 +1053,7 @@ describe('loggers', () => {
resultDisplay: { resultDisplay: {
fileDiff: 'diff', fileDiff: 'diff',
fileName: 'file.txt', fileName: 'file.txt',
filePath: 'file.txt',
originalContent: 'old content', originalContent: 'old content',
newContent: 'new content', newContent: 'new content',
diffStat: { diffStat: {
+2
View File
@@ -818,9 +818,11 @@ class EditToolInvocation
displayResult = { displayResult = {
fileDiff, fileDiff,
fileName, fileName,
filePath: this.params.file_path,
originalContent: editData.currentContent, originalContent: editData.currentContent,
newContent: editData.newContent, newContent: editData.newContent,
diffStat, diffStat,
isNewFile: editData.isNewFile,
}; };
} }
+2
View File
@@ -647,9 +647,11 @@ export interface Todo {
export interface FileDiff { export interface FileDiff {
fileDiff: string; fileDiff: string;
fileName: string; fileName: string;
filePath: string;
originalContent: string | null; originalContent: string | null;
newContent: string; newContent: string;
diffStat?: DiffStat; diffStat?: DiffStat;
isNewFile?: boolean;
} }
export interface DiffStat { export interface DiffStat {
+2
View File
@@ -346,9 +346,11 @@ class WriteFileToolInvocation extends BaseToolInvocation<
const displayResult: FileDiff = { const displayResult: FileDiff = {
fileDiff, fileDiff,
fileName, fileName,
filePath: this.resolvedPath,
originalContent: correctedContentResult.originalContent, originalContent: correctedContentResult.originalContent,
newContent: correctedContentResult.correctedContent, newContent: correctedContentResult.correctedContent,
diffStat, diffStat,
isNewFile,
}; };
return { return {