diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index d24179d056..2855888288 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -1406,6 +1406,268 @@ describe('CoreToolScheduler request queueing', () => { }); }); +describe('CoreToolScheduler Sequential Execution', () => { + it('should execute tool calls in a batch sequentially', async () => { + // Arrange + let firstCallFinished = false; + const executeFn = vi + .fn() + .mockImplementation(async (args: { call: number }) => { + if (args.call === 1) { + // First call, wait for a bit to simulate work + await new Promise((resolve) => setTimeout(resolve, 50)); + firstCallFinished = true; + return { llmContent: 'First call done' }; + } + if (args.call === 2) { + // Second call, should only happen after the first is finished + if (!firstCallFinished) { + throw new Error( + 'Second tool call started before the first one finished!', + ); + } + return { llmContent: 'Second call done' }; + } + return { llmContent: 'default' }; + }); + + const mockTool = new MockTool({ name: 'mockTool', execute: executeFn }); + const declarativeTool = mockTool; + + const mockToolRegistry = { + getTool: () => declarativeTool, + getToolByName: () => declarativeTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByDisplayName: () => declarativeTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + + const onAllToolCallsComplete = vi.fn(); + const onToolCallsUpdate = vi.fn(); + + const mockConfig = { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + getApprovalMode: () => ApprovalMode.YOLO, // Use YOLO to avoid confirmation prompts + getAllowedTools: () => [], + getContentGeneratorConfig: () => ({ + 
model: 'test-model', + authType: 'oauth-personal', + }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), + storage: { + getProjectTempDir: () => '/tmp', + }, + getToolRegistry: () => mockToolRegistry, + getTruncateToolOutputThreshold: () => + DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, + getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, + getUseSmartEdit: () => false, + getUseModelRouter: () => false, + getGeminiClient: () => null, + } as unknown as Config; + + const scheduler = new CoreToolScheduler({ + config: mockConfig, + onAllToolCallsComplete, + onToolCallsUpdate, + getPreferredEditor: () => 'vscode', + onEditorClose: vi.fn(), + }); + + const abortController = new AbortController(); + const requests = [ + { + callId: '1', + name: 'mockTool', + args: { call: 1 }, + isClientInitiated: false, + prompt_id: 'prompt-1', + }, + { + callId: '2', + name: 'mockTool', + args: { call: 2 }, + isClientInitiated: false, + prompt_id: 'prompt-1', + }, + ]; + + // Act + await scheduler.schedule(requests, abortController.signal); + + // Assert + await vi.waitFor(() => { + expect(onAllToolCallsComplete).toHaveBeenCalled(); + }); + + // Check that execute was called twice + expect(executeFn).toHaveBeenCalledTimes(2); + + // Check the order of calls + const calls = executeFn.mock.calls; + expect(calls[0][0]).toEqual({ call: 1 }); + expect(calls[1][0]).toEqual({ call: 2 }); + + // The onAllToolCallsComplete should be called once with both results + const completedCalls = onAllToolCallsComplete.mock + .calls[0][0] as ToolCall[]; + expect(completedCalls).toHaveLength(2); + expect(completedCalls[0].status).toBe('success'); + expect(completedCalls[1].status).toBe('success'); + }); + + it('should cancel subsequent tools when the signal is aborted.', async () => { + // Arrange + const abortController = new AbortController(); + let secondCallStarted = false; + + const executeFn = vi + .fn() + .mockImplementation(async (args: { call: 
number }) => { + if (args.call === 1) { + return { llmContent: 'First call done' }; + } + if (args.call === 2) { + secondCallStarted = true; + // This call will be cancelled while it's "running". + await new Promise((resolve) => setTimeout(resolve, 100)); + // It should not return a value because it will be cancelled. + return { llmContent: 'Second call should not complete' }; + } + if (args.call === 3) { + return { llmContent: 'Third call done' }; + } + return { llmContent: 'default' }; + }); + + const mockTool = new MockTool({ name: 'mockTool', execute: executeFn }); + const declarativeTool = mockTool; + + const mockToolRegistry = { + getTool: () => declarativeTool, + getToolByName: () => declarativeTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByDisplayName: () => declarativeTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + + const onAllToolCallsComplete = vi.fn(); + const onToolCallsUpdate = vi.fn(); + + const mockConfig = { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + getApprovalMode: () => ApprovalMode.YOLO, + getAllowedTools: () => [], + getContentGeneratorConfig: () => ({ + model: 'test-model', + authType: 'oauth-personal', + }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), + storage: { + getProjectTempDir: () => '/tmp', + }, + getToolRegistry: () => mockToolRegistry, + getTruncateToolOutputThreshold: () => + DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, + getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, + getUseSmartEdit: () => false, + getUseModelRouter: () => false, + getGeminiClient: () => null, + } as unknown as Config; + + const scheduler = new CoreToolScheduler({ + config: mockConfig, + onAllToolCallsComplete, + onToolCallsUpdate, + getPreferredEditor: () => 
'vscode', + onEditorClose: vi.fn(), + }); + + const requests = [ + { + callId: '1', + name: 'mockTool', + args: { call: 1 }, + isClientInitiated: false, + prompt_id: 'prompt-1', + }, + { + callId: '2', + name: 'mockTool', + args: { call: 2 }, + isClientInitiated: false, + prompt_id: 'prompt-1', + }, + { + callId: '3', + name: 'mockTool', + args: { call: 3 }, + isClientInitiated: false, + prompt_id: 'prompt-1', + }, + ]; + + // Act + const schedulePromise = scheduler.schedule( + requests, + abortController.signal, + ); + + // Wait for the second call to start, then abort. + await vi.waitFor(() => { + expect(secondCallStarted).toBe(true); + }); + abortController.abort(); + + await schedulePromise; + + // Assert + await vi.waitFor(() => { + expect(onAllToolCallsComplete).toHaveBeenCalled(); + }); + + // Check that execute was called for all three tools initially + expect(executeFn).toHaveBeenCalledTimes(3); + expect(executeFn).toHaveBeenCalledWith({ call: 1 }); + expect(executeFn).toHaveBeenCalledWith({ call: 2 }); + expect(executeFn).toHaveBeenCalledWith({ call: 3 }); + + const completedCalls = onAllToolCallsComplete.mock + .calls[0][0] as ToolCall[]; + expect(completedCalls).toHaveLength(3); + + const call1 = completedCalls.find((c) => c.request.callId === '1'); + const call2 = completedCalls.find((c) => c.request.callId === '2'); + const call3 = completedCalls.find((c) => c.request.callId === '3'); + + expect(call1?.status).toBe('success'); + expect(call2?.status).toBe('cancelled'); + expect(call3?.status).toBe('cancelled'); + }); +}); + describe('truncateAndSaveToFile', () => { const mockWriteFile = vi.mocked(fs.writeFile); const THRESHOLD = 40_000; diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index e1972a671a..d1d7829871 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -819,7 +819,7 @@ export class CoreToolScheduler { ); } } - 
this.attemptExecutionOfScheduledCalls(signal); + await this.attemptExecutionOfScheduledCalls(signal); void this.checkAndNotifyCompletion(); } finally { this.isScheduling = false; @@ -894,7 +894,7 @@ export class CoreToolScheduler { } this.setStatusInternal(callId, 'scheduled'); } - this.attemptExecutionOfScheduledCalls(signal); + await this.attemptExecutionOfScheduledCalls(signal); } /** @@ -940,7 +940,9 @@ export class CoreToolScheduler { }); } - private attemptExecutionOfScheduledCalls(signal: AbortSignal): void { + private async attemptExecutionOfScheduledCalls( + signal: AbortSignal, + ): Promise<void> { const allCallsFinalOrScheduled = this.toolCalls.every( (call) => call.status === 'scheduled' || @@ -954,8 +956,8 @@ export class CoreToolScheduler { (call) => call.status === 'scheduled', ); - callsToExecute.forEach((toolCall) => { - if (toolCall.status !== 'scheduled') return; + for (const toolCall of callsToExecute) { + if (toolCall.status !== 'scheduled') continue; const scheduledCall = toolCall; const { callId, name: toolName } = scheduledCall.request; @@ -1007,107 +1009,106 @@ export class CoreToolScheduler { ); } - promise - .then(async (toolResult: ToolResult) => { - if (signal.aborted) { - this.setStatusInternal( - callId, - 'cancelled', - 'User cancelled tool execution.', - ); - return; - } + try { + const toolResult: ToolResult = await promise; + if (signal.aborted) { + this.setStatusInternal( + callId, + 'cancelled', + 'User cancelled tool execution.', + ); + continue; + } - if (toolResult.error === undefined) { - let content = toolResult.llmContent; - let outputFile: string | undefined = undefined; - const contentLength = - typeof content === 'string' ? 
content.length : undefined; - if ( - typeof content === 'string' && - toolName === ShellTool.Name && - this.config.getEnableToolOutputTruncation() && - this.config.getTruncateToolOutputThreshold() > 0 && - this.config.getTruncateToolOutputLines() > 0 - ) { - const originalContentLength = content.length; - const threshold = this.config.getTruncateToolOutputThreshold(); - const lines = this.config.getTruncateToolOutputLines(); - const truncatedResult = await truncateAndSaveToFile( - content, - callId, - this.config.storage.getProjectTempDir(), - threshold, - lines, - ); - content = truncatedResult.content; - outputFile = truncatedResult.outputFile; - - if (outputFile) { - logToolOutputTruncated( - this.config, - new ToolOutputTruncatedEvent( - scheduledCall.request.prompt_id, - { - toolName, - originalContentLength, - truncatedContentLength: content.length, - threshold, - lines, - }, - ), - ); - } - } - - const response = convertToFunctionResponse( - toolName, - callId, + if (toolResult.error === undefined) { + let content = toolResult.llmContent; + let outputFile: string | undefined = undefined; + const contentLength = + typeof content === 'string' ? 
content.length : undefined; + if ( + typeof content === 'string' && + toolName === ShellTool.Name && + this.config.getEnableToolOutputTruncation() && + this.config.getTruncateToolOutputThreshold() > 0 && + this.config.getTruncateToolOutputLines() > 0 + ) { + const originalContentLength = content.length; + const threshold = this.config.getTruncateToolOutputThreshold(); + const lines = this.config.getTruncateToolOutputLines(); + const truncatedResult = await truncateAndSaveToFile( content, - ); - const successResponse: ToolCallResponseInfo = { callId, - responseParts: response, - resultDisplay: toolResult.returnDisplay, - error: undefined, - errorType: undefined, - outputFile, - contentLength, - }; - this.setStatusInternal(callId, 'success', successResponse); - } else { - // It is a failure - const error = new Error(toolResult.error.message); - const errorResponse = createErrorResponse( + this.config.storage.getProjectTempDir(), + threshold, + lines, + ); + content = truncatedResult.content; + outputFile = truncatedResult.outputFile; + + if (outputFile) { + logToolOutputTruncated( + this.config, + new ToolOutputTruncatedEvent( + scheduledCall.request.prompt_id, + { + toolName, + originalContentLength, + truncatedContentLength: content.length, + threshold, + lines, + }, + ), + ); + } + } + + const response = convertToFunctionResponse( + toolName, + callId, + content, + ); + const successResponse: ToolCallResponseInfo = { + callId, + responseParts: response, + resultDisplay: toolResult.returnDisplay, + error: undefined, + errorType: undefined, + outputFile, + contentLength, + }; + this.setStatusInternal(callId, 'success', successResponse); + } else { + // It is a failure + const error = new Error(toolResult.error.message); + const errorResponse = createErrorResponse( + scheduledCall.request, + error, + toolResult.error.type, + ); + this.setStatusInternal(callId, 'error', errorResponse); + } + } catch (executionError: unknown) { + if (signal.aborted) { + 
this.setStatusInternal( + callId, + 'cancelled', + 'User cancelled tool execution.', + ); + } else { + this.setStatusInternal( + callId, + 'error', + createErrorResponse( scheduledCall.request, - error, - toolResult.error.type, - ); - this.setStatusInternal(callId, 'error', errorResponse); - } - }) - .catch((executionError: Error) => { - if (signal.aborted) { - this.setStatusInternal( - callId, - 'cancelled', - 'User cancelled tool execution.', - ); - } else { - this.setStatusInternal( - callId, - 'error', - createErrorResponse( - scheduledCall.request, - executionError instanceof Error - ? executionError - : new Error(String(executionError)), - ToolErrorType.UNHANDLED_EXCEPTION, - ), - ); - } - }); - }); + executionError instanceof Error + ? executionError + : new Error(String(executionError)), + ToolErrorType.UNHANDLED_EXCEPTION, + ), + ); + } + } + } } } diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 6cdd817727..1e7f455955 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -23,8 +23,6 @@ import { setSimulate429 } from '../utils/testUtils.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { AuthType } from './contentGenerator.js'; import { type RetryOptions } from '../utils/retry.js'; -import type { ToolRegistry } from '../tools/tool-registry.js'; -import { Kind } from '../tools/tools.js'; import { uiTelemetryService } from '../telemetry/uiTelemetry.js'; // Mock fs module to prevent actual file system operations during tests @@ -1353,259 +1351,6 @@ describe('GeminiChat', () => { expect(turn4.parts[0].text).toBe('second response'); }); - describe('stopBeforeSecondMutator', () => { - beforeEach(() => { - // Common setup for these tests: mock the tool registry. 
- const mockToolRegistry = { - getTool: vi.fn((toolName: string) => { - if (toolName === 'edit') { - return { kind: Kind.Edit }; - } - return { kind: Kind.Other }; - }), - } as unknown as ToolRegistry; - vi.mocked(mockConfig.getToolRegistry).mockReturnValue(mockToolRegistry); - }); - - it('should stop streaming before a second mutator tool call', async () => { - const responses = [ - { - candidates: [ - { content: { role: 'model', parts: [{ text: 'First part. ' }] } }, - ], - }, - { - candidates: [ - { - content: { - role: 'model', - parts: [{ functionCall: { name: 'edit', args: {} } }], - }, - }, - ], - }, - { - candidates: [ - { - content: { - role: 'model', - parts: [{ functionCall: { name: 'fetch', args: {} } }], - }, - }, - ], - }, - // This chunk contains the second mutator and should be clipped. - { - candidates: [ - { - content: { - role: 'model', - parts: [ - { functionCall: { name: 'edit', args: {} } }, - { text: 'some trailing text' }, - ], - }, - }, - ], - }, - // This chunk should never be reached. - { - candidates: [ - { - content: { - role: 'model', - parts: [{ text: 'This should not appear.' }], - }, - }, - ], - }, - ] as unknown as GenerateContentResponse[]; - - vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( - (async function* () { - for (const response of responses) { - yield response; - } - })(), - ); - - const stream = await chat.sendMessageStream( - 'test-model', - { message: 'test message' }, - 'prompt-id-mutator-test', - ); - for await (const _ of stream) { - // Consume the stream to trigger history recording. - } - - const history = chat.getHistory(); - expect(history.length).toBe(2); - - const modelTurn = history[1]!; - expect(modelTurn.role).toBe('model'); - expect(modelTurn?.parts?.length).toBe(3); - expect(modelTurn?.parts![0]!.text).toBe('First part. 
'); - expect(modelTurn.parts![1]!.functionCall?.name).toBe('edit'); - expect(modelTurn.parts![2]!.functionCall?.name).toBe('fetch'); - }); - - it('should not stop streaming if only one mutator is present', async () => { - const responses = [ - { - candidates: [ - { content: { role: 'model', parts: [{ text: 'Part 1. ' }] } }, - ], - }, - { - candidates: [ - { - content: { - role: 'model', - parts: [{ functionCall: { name: 'edit', args: {} } }], - }, - }, - ], - }, - { - candidates: [ - { - content: { - role: 'model', - parts: [{ text: 'Part 2.' }], - }, - finishReason: 'STOP', - }, - ], - }, - ] as unknown as GenerateContentResponse[]; - - vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( - (async function* () { - for (const response of responses) { - yield response; - } - })(), - ); - - const stream = await chat.sendMessageStream( - 'test-model', - { message: 'test message' }, - 'prompt-id-one-mutator', - ); - for await (const _ of stream) { - /* consume */ - } - - const history = chat.getHistory(); - const modelTurn = history[1]!; - expect(modelTurn?.parts?.length).toBe(3); - expect(modelTurn.parts![1]!.functionCall?.name).toBe('edit'); - expect(modelTurn.parts![2]!.text).toBe('Part 2.'); - }); - - it('should clip the chunk containing the second mutator, preserving prior parts', async () => { - const responses = [ - { - candidates: [ - { - content: { - role: 'model', - parts: [{ functionCall: { name: 'edit', args: {} } }], - }, - }, - ], - }, - // This chunk has a valid part before the second mutator. - // The valid part should be kept, the rest of the chunk discarded. - { - candidates: [ - { - content: { - role: 'model', - parts: [ - { text: 'Keep this text. ' }, - { functionCall: { name: 'edit', args: {} } }, - { text: 'Discard this text.' 
}, - ], - }, - finishReason: 'STOP', - }, - ], - }, - ] as unknown as GenerateContentResponse[]; - - const stream = (async function* () { - for (const response of responses) { - yield response; - } - })(); - - vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( - stream, - ); - - const resultStream = await chat.sendMessageStream( - 'test-model', - { message: 'test' }, - 'prompt-id-clip-chunk', - ); - for await (const _ of resultStream) { - /* consume */ - } - - const history = chat.getHistory(); - const modelTurn = history[1]!; - expect(modelTurn?.parts?.length).toBe(2); - expect(modelTurn.parts![0]!.functionCall?.name).toBe('edit'); - expect(modelTurn.parts![1]!.text).toBe('Keep this text. '); - }); - - it('should handle two mutators in the same chunk (parallel call scenario)', async () => { - const responses = [ - { - candidates: [ - { - content: { - role: 'model', - parts: [ - { text: 'Some text. ' }, - { functionCall: { name: 'edit', args: {} } }, - { functionCall: { name: 'edit', args: {} } }, - ], - }, - finishReason: 'STOP', - }, - ], - }, - ] as unknown as GenerateContentResponse[]; - - const stream = (async function* () { - for (const response of responses) { - yield response; - } - })(); - - vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( - stream, - ); - - const resultStream = await chat.sendMessageStream( - 'test-model', - { message: 'test' }, - 'prompt-id-parallel-mutators', - ); - for await (const _ of resultStream) { - /* consume */ - } - - const history = chat.getHistory(); - const modelTurn = history[1]!; - expect(modelTurn?.parts?.length).toBe(2); - expect(modelTurn.parts![0]!.text).toBe('Some text. 
'); - expect(modelTurn.parts![1]!.functionCall?.name).toBe('edit'); - }); - }); - describe('Model Resolution', () => { const mockResponse = { candidates: [ diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index f172387839..45b37c9aeb 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -7,14 +7,13 @@ // DISCLAIMER: This is a copied version of https://github.com/googleapis/js-genai/blob/main/src/chats.ts with the intention of working around a key bug // where function responses are not treated as "valid" responses: https://b.corp.google.com/issues/420354090 -import { +import type { GenerateContentResponse, - type Content, - type GenerateContentConfig, - type SendMessageParameters, - type Part, - type Tool, - FinishReason, + Content, + GenerateContentConfig, + SendMessageParameters, + Part, + Tool, } from '@google/genai'; import { toParts } from '../code_assist/converter.js'; import { createUserContent } from '@google/genai'; @@ -24,7 +23,7 @@ import { DEFAULT_GEMINI_FLASH_MODEL, getEffectiveModel, } from '../config/models.js'; -import { hasCycleInSchema, MUTATOR_KINDS } from '../tools/tools.js'; +import { hasCycleInSchema } from '../tools/tools.js'; import type { StructuredError } from './turn.js'; import type { CompletedToolCall } from './coreToolScheduler.js'; import { @@ -496,7 +495,7 @@ export class GeminiChat { let hasToolCall = false; let hasFinishReason = false; - for await (const chunk of this.stopBeforeSecondMutator(streamResponse)) { + for await (const chunk of streamResponse) { hasFinishReason = chunk?.candidates?.some((candidate) => candidate.finishReason) ?? false; if (isValidResponse(chunk)) { @@ -641,64 +640,6 @@ export class GeminiChat { }); } } - - /** - * Truncates the chunkStream right before the second function call to a - * function that mutates state. This may involve trimming parts from a chunk - * as well as omtting some chunks altogether. 
- * - * We do this because it improves tool call quality if the model gets - * feedback from one mutating function call before it makes the next one. - */ - private async *stopBeforeSecondMutator( - chunkStream: AsyncGenerator<GenerateContentResponse>, - ): AsyncGenerator<GenerateContentResponse> { - let foundMutatorFunctionCall = false; - - for await (const chunk of chunkStream) { - const candidate = chunk.candidates?.[0]; - const content = candidate?.content; - if (!candidate || !content?.parts) { - yield chunk; - continue; - } - - const truncatedParts: Part[] = []; - for (const part of content.parts) { - if (this.isMutatorFunctionCall(part)) { - if (foundMutatorFunctionCall) { - // This is the second mutator call. - // Truncate and return immedaitely. - const newChunk = new GenerateContentResponse(); - newChunk.candidates = [ - { - ...candidate, - content: { - ...content, - parts: truncatedParts, - }, - finishReason: FinishReason.STOP, - }, - ]; - yield newChunk; - return; - } - foundMutatorFunctionCall = true; - } - truncatedParts.push(part); - } - - yield chunk; - } - } - - private isMutatorFunctionCall(part: Part): boolean { - if (!part?.functionCall?.name) { - return false; - } - const tool = this.config.getToolRegistry().getTool(part.functionCall.name); - return !!tool && MUTATOR_KINDS.includes(tool.kind); - } } /** Visible for Testing */