mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-19 02:20:42 -07:00
feat: Allow cancellation of in-progress Gemini requests and pre-execution checks
- Implements cancellation for Gemini requests while they are actively being processed by the model. - Extends cancellation support to the logic within tools. This allows users to cancel operations during the phase where the system is determining if a tool execution requires user confirmation, which can include potentially long-running pre-flight checks or LLM-based corrections. - Underlying LLM calls for edit corrections (within and ) and next speaker checks can now also be cancelled. - Previously, cancellation of the main request was not possible until text started streaming, and pre-execution checks were not cancellable. - This change leverages the updated SDK's ability to accept an abort token and threads s throughout the request, tool execution, and pre-execution check lifecycle. Fixes https://github.com/google-gemini/gemini-cli/issues/531
This commit is contained in:
committed by
N. Taylor Mullen
parent
bfeaac8441
commit
f2f2ecf9d8
@@ -223,7 +223,9 @@ describe('EditTool', () => {
|
||||
old_string: 'old',
|
||||
new_string: 'new',
|
||||
};
|
||||
expect(await tool.shouldConfirmExecute(params)).toBe(false);
|
||||
expect(
|
||||
await tool.shouldConfirmExecute(params, new AbortController().signal),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it('should request confirmation for valid edit', async () => {
|
||||
@@ -235,7 +237,10 @@ describe('EditTool', () => {
|
||||
};
|
||||
// ensureCorrectEdit will be called by shouldConfirmExecute
|
||||
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 1 });
|
||||
const confirmation = await tool.shouldConfirmExecute(params);
|
||||
const confirmation = await tool.shouldConfirmExecute(
|
||||
params,
|
||||
new AbortController().signal,
|
||||
);
|
||||
expect(confirmation).toEqual(
|
||||
expect.objectContaining({
|
||||
title: `Confirm Edit: ${testFile}`,
|
||||
@@ -253,7 +258,9 @@ describe('EditTool', () => {
|
||||
new_string: 'new',
|
||||
};
|
||||
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 });
|
||||
expect(await tool.shouldConfirmExecute(params)).toBe(false);
|
||||
expect(
|
||||
await tool.shouldConfirmExecute(params, new AbortController().signal),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false if multiple occurrences of old_string are found (ensureCorrectEdit returns > 1)', async () => {
|
||||
@@ -264,7 +271,9 @@ describe('EditTool', () => {
|
||||
new_string: 'new',
|
||||
};
|
||||
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 2 });
|
||||
expect(await tool.shouldConfirmExecute(params)).toBe(false);
|
||||
expect(
|
||||
await tool.shouldConfirmExecute(params, new AbortController().signal),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it('should request confirmation for creating a new file (empty old_string)', async () => {
|
||||
@@ -279,7 +288,10 @@ describe('EditTool', () => {
|
||||
// as shouldConfirmExecute handles this for diff generation.
|
||||
// If it is called, it should return 0 occurrences for a new file.
|
||||
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 });
|
||||
const confirmation = await tool.shouldConfirmExecute(params);
|
||||
const confirmation = await tool.shouldConfirmExecute(
|
||||
params,
|
||||
new AbortController().signal,
|
||||
);
|
||||
expect(confirmation).toEqual(
|
||||
expect.objectContaining({
|
||||
title: `Confirm Edit: ${newFileName}`,
|
||||
@@ -328,6 +340,7 @@ describe('EditTool', () => {
|
||||
|
||||
const confirmation = (await tool.shouldConfirmExecute(
|
||||
params,
|
||||
new AbortController().signal,
|
||||
)) as FileDiff;
|
||||
|
||||
expect(mockCalled).toBe(true); // Check if the mock implementation was run
|
||||
|
||||
@@ -174,7 +174,10 @@ Expectation for parameters:
|
||||
* @returns An object describing the potential edit outcome
|
||||
* @throws File system errors if reading the file fails unexpectedly (e.g., permissions)
|
||||
*/
|
||||
private async calculateEdit(params: EditToolParams): Promise<CalculatedEdit> {
|
||||
private async calculateEdit(
|
||||
params: EditToolParams,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<CalculatedEdit> {
|
||||
const expectedReplacements = 1;
|
||||
let currentContent: string | null = null;
|
||||
let fileExists = false;
|
||||
@@ -210,6 +213,7 @@ Expectation for parameters:
|
||||
currentContent,
|
||||
params,
|
||||
this.client,
|
||||
abortSignal,
|
||||
);
|
||||
finalOldString = correctedEdit.params.old_string;
|
||||
finalNewString = correctedEdit.params.new_string;
|
||||
@@ -262,6 +266,7 @@ Expectation for parameters:
|
||||
*/
|
||||
async shouldConfirmExecute(
|
||||
params: EditToolParams,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getAlwaysSkipModificationConfirmation()) {
|
||||
return false;
|
||||
@@ -300,6 +305,7 @@ Expectation for parameters:
|
||||
currentContent,
|
||||
params,
|
||||
this.client,
|
||||
abortSignal,
|
||||
);
|
||||
finalOldString = correctedEdit.params.old_string;
|
||||
finalNewString = correctedEdit.params.new_string;
|
||||
@@ -376,7 +382,7 @@ Expectation for parameters:
|
||||
|
||||
let editData: CalculatedEdit;
|
||||
try {
|
||||
editData = await this.calculateEdit(params);
|
||||
editData = await this.calculateEdit(params, _signal);
|
||||
} catch (error) {
|
||||
const errorMsg = error instanceof Error ? error.message : String(error);
|
||||
return {
|
||||
|
||||
@@ -98,6 +98,7 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
|
||||
|
||||
async shouldConfirmExecute(
|
||||
params: ShellToolParams,
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.validateToolParams(params)) {
|
||||
return false; // skip confirmation, execute call will fail immediately
|
||||
|
||||
@@ -57,6 +57,7 @@ export interface Tool<
|
||||
*/
|
||||
shouldConfirmExecute(
|
||||
params: TParams,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false>;
|
||||
|
||||
/**
|
||||
@@ -137,6 +138,8 @@ export abstract class BaseTool<
|
||||
shouldConfirmExecute(
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
params: TParams,
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
return Promise.resolve(false);
|
||||
}
|
||||
|
||||
@@ -110,18 +110,32 @@ describe('WriteFileTool', () => {
|
||||
// Default mock implementations that return valid structures
|
||||
mockEnsureCorrectEdit.mockImplementation(
|
||||
async (
|
||||
currentContent: string,
|
||||
_currentContent: string,
|
||||
params: EditToolParams,
|
||||
_client: GeminiClient,
|
||||
): Promise<CorrectedEditResult> =>
|
||||
Promise.resolve({
|
||||
signal?: AbortSignal, // Make AbortSignal optional to match usage
|
||||
): Promise<CorrectedEditResult> => {
|
||||
if (signal?.aborted) {
|
||||
return Promise.reject(new Error('Aborted'));
|
||||
}
|
||||
return Promise.resolve({
|
||||
params: { ...params, new_string: params.new_string ?? '' },
|
||||
occurrences: 1,
|
||||
}),
|
||||
});
|
||||
},
|
||||
);
|
||||
mockEnsureCorrectFileContent.mockImplementation(
|
||||
async (content: string, _client: GeminiClient): Promise<string> =>
|
||||
Promise.resolve(content ?? ''),
|
||||
async (
|
||||
content: string,
|
||||
_client: GeminiClient,
|
||||
signal?: AbortSignal,
|
||||
): Promise<string> => {
|
||||
// Make AbortSignal optional
|
||||
if (signal?.aborted) {
|
||||
return Promise.reject(new Error('Aborted'));
|
||||
}
|
||||
return Promise.resolve(content ?? '');
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
@@ -181,6 +195,7 @@ describe('WriteFileTool', () => {
|
||||
const filePath = path.join(rootDir, 'new_corrected_file.txt');
|
||||
const proposedContent = 'Proposed new content.';
|
||||
const correctedContent = 'Corrected new content.';
|
||||
const abortSignal = new AbortController().signal;
|
||||
// Ensure the mock is set for this specific test case if needed, or rely on beforeEach
|
||||
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent);
|
||||
|
||||
@@ -188,11 +203,13 @@ describe('WriteFileTool', () => {
|
||||
const result = await tool._getCorrectedFileContent(
|
||||
filePath,
|
||||
proposedContent,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||
proposedContent,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockEnsureCorrectEdit).not.toHaveBeenCalled();
|
||||
expect(result.correctedContent).toBe(correctedContent);
|
||||
@@ -206,6 +223,7 @@ describe('WriteFileTool', () => {
|
||||
const originalContent = 'Original existing content.';
|
||||
const proposedContent = 'Proposed replacement content.';
|
||||
const correctedProposedContent = 'Corrected replacement content.';
|
||||
const abortSignal = new AbortController().signal;
|
||||
fs.writeFileSync(filePath, originalContent, 'utf8');
|
||||
|
||||
// Ensure this mock is active and returns the correct structure
|
||||
@@ -222,6 +240,7 @@ describe('WriteFileTool', () => {
|
||||
const result = await tool._getCorrectedFileContent(
|
||||
filePath,
|
||||
proposedContent,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
||||
@@ -232,6 +251,7 @@ describe('WriteFileTool', () => {
|
||||
file_path: filePath,
|
||||
},
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockEnsureCorrectFileContent).not.toHaveBeenCalled();
|
||||
expect(result.correctedContent).toBe(correctedProposedContent);
|
||||
@@ -243,6 +263,7 @@ describe('WriteFileTool', () => {
|
||||
it('should return error if reading an existing file fails (e.g. permissions)', async () => {
|
||||
const filePath = path.join(rootDir, 'unreadable_file.txt');
|
||||
const proposedContent = 'some content';
|
||||
const abortSignal = new AbortController().signal;
|
||||
fs.writeFileSync(filePath, 'content', { mode: 0o000 });
|
||||
|
||||
const readError = new Error('Permission denied');
|
||||
@@ -255,6 +276,7 @@ describe('WriteFileTool', () => {
|
||||
const result = await tool._getCorrectedFileContent(
|
||||
filePath,
|
||||
proposedContent,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(fs.readFileSync).toHaveBeenCalledWith(filePath, 'utf8');
|
||||
@@ -274,16 +296,17 @@ describe('WriteFileTool', () => {
|
||||
});
|
||||
|
||||
describe('shouldConfirmExecute', () => {
|
||||
const abortSignal = new AbortController().signal;
|
||||
it('should return false if params are invalid (relative path)', async () => {
|
||||
const params = { file_path: 'relative.txt', content: 'test' };
|
||||
const confirmation = await tool.shouldConfirmExecute(params);
|
||||
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
|
||||
expect(confirmation).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false if params are invalid (outside root)', async () => {
|
||||
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
|
||||
const params = { file_path: outsidePath, content: 'test' };
|
||||
const confirmation = await tool.shouldConfirmExecute(params);
|
||||
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
|
||||
expect(confirmation).toBe(false);
|
||||
});
|
||||
|
||||
@@ -298,7 +321,7 @@ describe('WriteFileTool', () => {
|
||||
throw readError;
|
||||
});
|
||||
|
||||
const confirmation = await tool.shouldConfirmExecute(params);
|
||||
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
|
||||
expect(confirmation).toBe(false);
|
||||
|
||||
vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync);
|
||||
@@ -314,11 +337,13 @@ describe('WriteFileTool', () => {
|
||||
const params = { file_path: filePath, content: proposedContent };
|
||||
const confirmation = (await tool.shouldConfirmExecute(
|
||||
params,
|
||||
abortSignal,
|
||||
)) as ToolEditConfirmationDetails;
|
||||
|
||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||
proposedContent,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(confirmation).toEqual(
|
||||
expect.objectContaining({
|
||||
@@ -343,7 +368,6 @@ describe('WriteFileTool', () => {
|
||||
'Corrected replacement for confirmation.';
|
||||
fs.writeFileSync(filePath, originalContent, 'utf8');
|
||||
|
||||
// Ensure this mock is active and returns the correct structure
|
||||
mockEnsureCorrectEdit.mockResolvedValue({
|
||||
params: {
|
||||
file_path: filePath,
|
||||
@@ -356,6 +380,7 @@ describe('WriteFileTool', () => {
|
||||
const params = { file_path: filePath, content: proposedContent };
|
||||
const confirmation = (await tool.shouldConfirmExecute(
|
||||
params,
|
||||
abortSignal,
|
||||
)) as ToolEditConfirmationDetails;
|
||||
|
||||
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
||||
@@ -366,6 +391,7 @@ describe('WriteFileTool', () => {
|
||||
file_path: filePath,
|
||||
},
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(confirmation).toEqual(
|
||||
expect.objectContaining({
|
||||
@@ -381,9 +407,10 @@ describe('WriteFileTool', () => {
|
||||
});
|
||||
|
||||
describe('execute', () => {
|
||||
const abortSignal = new AbortController().signal;
|
||||
it('should return error if params are invalid (relative path)', async () => {
|
||||
const params = { file_path: 'relative.txt', content: 'test' };
|
||||
const result = await tool.execute(params, new AbortController().signal);
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(result.llmContent).toMatch(/Error: Invalid parameters provided/);
|
||||
expect(result.returnDisplay).toMatch(/Error: File path must be absolute/);
|
||||
});
|
||||
@@ -391,7 +418,7 @@ describe('WriteFileTool', () => {
|
||||
it('should return error if params are invalid (path outside root)', async () => {
|
||||
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
|
||||
const params = { file_path: outsidePath, content: 'test' };
|
||||
const result = await tool.execute(params, new AbortController().signal);
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(result.llmContent).toMatch(/Error: Invalid parameters provided/);
|
||||
expect(result.returnDisplay).toMatch(
|
||||
/Error: File path must be within the root directory/,
|
||||
@@ -409,7 +436,7 @@ describe('WriteFileTool', () => {
|
||||
throw readError;
|
||||
});
|
||||
|
||||
const result = await tool.execute(params, new AbortController().signal);
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
expect(result.llmContent).toMatch(/Error checking existing file/);
|
||||
expect(result.returnDisplay).toMatch(
|
||||
/Error checking existing file: Simulated read error for execute/,
|
||||
@@ -427,16 +454,20 @@ describe('WriteFileTool', () => {
|
||||
|
||||
const params = { file_path: filePath, content: proposedContent };
|
||||
|
||||
const confirmDetails = await tool.shouldConfirmExecute(params);
|
||||
const confirmDetails = await tool.shouldConfirmExecute(
|
||||
params,
|
||||
abortSignal,
|
||||
);
|
||||
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
|
||||
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
||||
}
|
||||
|
||||
const result = await tool.execute(params, new AbortController().signal);
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
|
||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||
proposedContent,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result.llmContent).toMatch(
|
||||
/Successfully created and wrote to new file/,
|
||||
@@ -477,12 +508,15 @@ describe('WriteFileTool', () => {
|
||||
|
||||
const params = { file_path: filePath, content: proposedContent };
|
||||
|
||||
const confirmDetails = await tool.shouldConfirmExecute(params);
|
||||
const confirmDetails = await tool.shouldConfirmExecute(
|
||||
params,
|
||||
abortSignal,
|
||||
);
|
||||
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
|
||||
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
||||
}
|
||||
|
||||
const result = await tool.execute(params, new AbortController().signal);
|
||||
const result = await tool.execute(params, abortSignal);
|
||||
|
||||
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
||||
initialContent,
|
||||
@@ -492,6 +526,7 @@ describe('WriteFileTool', () => {
|
||||
file_path: filePath,
|
||||
},
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result.llmContent).toMatch(/Successfully overwrote file/);
|
||||
expect(fs.readFileSync(filePath, 'utf8')).toBe(correctedProposedContent);
|
||||
@@ -513,12 +548,15 @@ describe('WriteFileTool', () => {
|
||||
|
||||
const params = { file_path: filePath, content };
|
||||
// Simulate confirmation if your logic requires it before execute, or remove if not needed for this path
|
||||
const confirmDetails = await tool.shouldConfirmExecute(params);
|
||||
const confirmDetails = await tool.shouldConfirmExecute(
|
||||
params,
|
||||
abortSignal,
|
||||
);
|
||||
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
|
||||
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
||||
}
|
||||
|
||||
await tool.execute(params, new AbortController().signal);
|
||||
await tool.execute(params, abortSignal);
|
||||
|
||||
expect(fs.existsSync(dirPath)).toBe(true);
|
||||
expect(fs.statSync(dirPath).isDirectory()).toBe(true);
|
||||
|
||||
@@ -141,6 +141,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||
*/
|
||||
async shouldConfirmExecute(
|
||||
params: WriteFileToolParams,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getAlwaysSkipModificationConfirmation()) {
|
||||
return false;
|
||||
@@ -154,6 +155,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||
const correctedContentResult = await this._getCorrectedFileContent(
|
||||
params.file_path,
|
||||
params.content,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
if (correctedContentResult.error) {
|
||||
@@ -193,7 +195,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||
|
||||
async execute(
|
||||
params: WriteFileToolParams,
|
||||
_signal: AbortSignal,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
@@ -206,6 +208,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||
const correctedContentResult = await this._getCorrectedFileContent(
|
||||
params.file_path,
|
||||
params.content,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
if (correctedContentResult.error) {
|
||||
@@ -277,6 +280,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||
private async _getCorrectedFileContent(
|
||||
filePath: string,
|
||||
proposedContent: string,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<GetCorrectedFileContentResult> {
|
||||
let originalContent = '';
|
||||
let fileExists = false;
|
||||
@@ -316,6 +320,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||
file_path: filePath,
|
||||
},
|
||||
this.client,
|
||||
abortSignal,
|
||||
);
|
||||
correctedContent = correctedParams.new_string;
|
||||
} else {
|
||||
@@ -323,6 +328,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||
correctedContent = await ensureCorrectFileContent(
|
||||
proposedContent,
|
||||
this.client,
|
||||
abortSignal,
|
||||
);
|
||||
}
|
||||
return { originalContent, correctedContent, fileExists };
|
||||
|
||||
Reference in New Issue
Block a user