From 6e51bbc215c706109d26086449d775c8ad2093fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Dec?= <64912735+koxkox111@users.noreply.github.com> Date: Mon, 8 Dec 2025 16:12:47 +0100 Subject: [PATCH] Add prompt_id propagation in a2a-server task (#14581) --- packages/a2a-server/src/agent/task.test.ts | 86 +++++++++++++++++++++- packages/a2a-server/src/agent/task.ts | 10 ++- 2 files changed, 91 insertions(+), 5 deletions(-) diff --git a/packages/a2a-server/src/agent/task.test.ts b/packages/a2a-server/src/agent/task.test.ts index d514079592..47448b8c49 100644 --- a/packages/a2a-server/src/agent/task.test.ts +++ b/packages/a2a-server/src/agent/task.test.ts @@ -18,9 +18,10 @@ import { GeminiEventType, type Config, type ToolCallRequestInfo, + type CompletedToolCall, } from '@google/gemini-cli-core'; import { createMockConfig } from '../utils/testing_utils.js'; -import type { ExecutionEventBus } from '@a2a-js/sdk/server'; +import type { ExecutionEventBus, RequestContext } from '@a2a-js/sdk/server'; import { CoderAgentEvent } from '../types.js'; import type { ToolCall } from '@google/gemini-cli-core'; @@ -318,4 +319,87 @@ describe('Task', () => { expect(finalCall).toBeUndefined(); }); }); + + describe('currentPromptId and promptCount', () => { + it('should correctly initialize and update promptId and promptCount', async () => { + const mockConfig = createMockConfig(); + mockConfig.getGeminiClient = vi.fn().mockReturnValue({ + sendMessageStream: vi.fn().mockReturnValue((async function* () {})()), + }); + mockConfig.getSessionId = () => 'test-session-id'; + + const mockEventBus: ExecutionEventBus = { + publish: vi.fn(), + on: vi.fn(), + off: vi.fn(), + once: vi.fn(), + removeAllListeners: vi.fn(), + finished: vi.fn(), + }; + + // @ts-expect-error - Calling private constructor + const task = new Task( + 'task-id', + 'context-id', + mockConfig as Config, + mockEventBus, + ); + + // Initial state + expect(task.currentPromptId).toBeUndefined(); + expect(task.promptCount).toBe(0); + + // First user message should set prompt_id + const userMessage1 = { + userMessage: { + parts: [{ kind: 'text', text: 'hello' }], + }, + } as RequestContext; + const abortController1 = new AbortController(); + for await (const _ of task.acceptUserMessage( + userMessage1, + abortController1.signal, + )) { + // no-op + } + + const expectedPromptId1 = 'test-session-id########0'; + expect(task.promptCount).toBe(1); + expect(task.currentPromptId).toBe(expectedPromptId1); + + // A new user message should generate a new prompt_id + const userMessage2 = { + userMessage: { + parts: [{ kind: 'text', text: 'world' }], + }, + } as RequestContext; + const abortController2 = new AbortController(); + for await (const _ of task.acceptUserMessage( + userMessage2, + abortController2.signal, + )) { + // no-op + } + + const expectedPromptId2 = 'test-session-id########1'; + expect(task.promptCount).toBe(2); + expect(task.currentPromptId).toBe(expectedPromptId2); + + // Subsequent tool call processing should use the same prompt_id + const completedTool = { + request: { callId: 'tool-1' }, + response: { responseParts: [{ text: 'tool output' }] }, + } as CompletedToolCall; + const abortController3 = new AbortController(); + for await (const _ of task.sendCompletedToolsToLlm( + [completedTool], + abortController3.signal, + )) { + // no-op + } + + expect(task.promptCount).toBe(2); + expect(task.currentPromptId).toBe(expectedPromptId2); + }); + }); }); diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index 5dacad9e6f..cca753b242 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -68,6 +68,8 @@ export class Task { completedToolCalls: CompletedToolCall[]; skipFinalTrueAfterInlineEdit = false; modelInfo?: string; + currentPromptId: string | undefined; + promptCount = 0; // For tool waiting logic private pendingToolCalls: Map = new Map(); //toolCallId --> status @@ -859,11 +861,10 @@ export class Task { }; // Set task state to working as we are about to call LLM this.setTaskStateAndPublishUpdate('working', stateChange); - // TODO: Determine what it mean to have, then add a prompt ID. yield* this.geminiClient.sendMessageStream( llmParts, aborted, - /*prompt_id*/ '', + completedToolCalls[0]?.request.prompt_id ?? '', ); } @@ -893,17 +894,18 @@ export class Task { } if (hasContentForLlm) { + this.currentPromptId = + this.config.getSessionId() + '########' + this.promptCount++; logger.info('[Task] Sending new parts to LLM.'); const stateChange: StateChange = { kind: CoderAgentEvent.StateChangeEvent, }; // Set task state to working as we are about to call LLM this.setTaskStateAndPublishUpdate('working', stateChange); - // TODO: Determine what it mean to have, then add a prompt ID. yield* this.geminiClient.sendMessageStream( llmParts, aborted, - /*prompt_id*/ '', + this.currentPromptId, ); } else if (anyConfirmationHandled) { logger.info(