Add prompt_id propagation in a2a-server task (#14581)

This commit is contained in:
Paweł Dec
2025-12-08 16:12:47 +01:00
committed by GitHub
parent 8f4f8baa81
commit 6e51bbc215
2 changed files with 91 additions and 5 deletions

View File

@@ -18,9 +18,10 @@ import {
GeminiEventType,
type Config,
type ToolCallRequestInfo,
type CompletedToolCall,
} from '@google/gemini-cli-core';
import { createMockConfig } from '../utils/testing_utils.js';
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
import type { ExecutionEventBus, RequestContext } from '@a2a-js/sdk/server';
import { CoderAgentEvent } from '../types.js';
import type { ToolCall } from '@google/gemini-cli-core';
@@ -318,4 +319,87 @@ describe('Task', () => {
expect(finalCall).toBeUndefined();
});
});
describe('currentPromptId and promptCount', () => {
it('should correctly initialize and update promptId and promptCount', async () => {
const mockConfig = createMockConfig();
mockConfig.getGeminiClient = vi.fn().mockReturnValue({
sendMessageStream: vi.fn().mockReturnValue((async function* () {})()),
});
mockConfig.getSessionId = () => 'test-session-id';
const mockEventBus: ExecutionEventBus = {
publish: vi.fn(),
on: vi.fn(),
off: vi.fn(),
once: vi.fn(),
removeAllListeners: vi.fn(),
finished: vi.fn(),
};
// @ts-expect-error - Calling private constructor
const task = new Task(
'task-id',
'context-id',
mockConfig as Config,
mockEventBus,
);
// Initial state
expect(task.currentPromptId).toBeUndefined();
expect(task.promptCount).toBe(0);
// First user message should set prompt_id
const userMessage1 = {
userMessage: {
parts: [{ kind: 'text', text: 'hello' }],
},
} as RequestContext;
const abortController1 = new AbortController();
for await (const _ of task.acceptUserMessage(
userMessage1,
abortController1.signal,
)) {
// no-op
}
const expectedPromptId1 = 'test-session-id########0';
expect(task.promptCount).toBe(1);
expect(task.currentPromptId).toBe(expectedPromptId1);
// A new user message should generate a new prompt_id
const userMessage2 = {
userMessage: {
parts: [{ kind: 'text', text: 'world' }],
},
} as RequestContext;
const abortController2 = new AbortController();
for await (const _ of task.acceptUserMessage(
userMessage2,
abortController2.signal,
)) {
// no-op
}
const expectedPromptId2 = 'test-session-id########1';
expect(task.promptCount).toBe(2);
expect(task.currentPromptId).toBe(expectedPromptId2);
// Subsequent tool call processing should use the same prompt_id
const completedTool = {
request: { callId: 'tool-1' },
response: { responseParts: [{ text: 'tool output' }] },
} as CompletedToolCall;
const abortController3 = new AbortController();
for await (const _ of task.sendCompletedToolsToLlm(
[completedTool],
abortController3.signal,
)) {
// no-op
}
expect(task.promptCount).toBe(2);
expect(task.currentPromptId).toBe(expectedPromptId2);
});
});
});

View File

@@ -68,6 +68,8 @@ export class Task {
completedToolCalls: CompletedToolCall[];
skipFinalTrueAfterInlineEdit = false;
modelInfo?: string;
currentPromptId: string | undefined;
promptCount = 0;
// For tool waiting logic
private pendingToolCalls: Map<string, string> = new Map(); //toolCallId --> status
@@ -859,11 +861,10 @@ export class Task {
};
// Set task state to working as we are about to call LLM
this.setTaskStateAndPublishUpdate('working', stateChange);
// TODO: Determine what it mean to have, then add a prompt ID.
yield* this.geminiClient.sendMessageStream(
llmParts,
aborted,
/*prompt_id*/ '',
completedToolCalls[0]?.request.prompt_id ?? '',
);
}
@@ -893,17 +894,18 @@ export class Task {
}
if (hasContentForLlm) {
this.currentPromptId =
this.config.getSessionId() + '########' + this.promptCount++;
logger.info('[Task] Sending new parts to LLM.');
const stateChange: StateChange = {
kind: CoderAgentEvent.StateChangeEvent,
};
// Set task state to working as we are about to call LLM
this.setTaskStateAndPublishUpdate('working', stateChange);
// TODO: Determine what it mean to have, then add a prompt ID.
yield* this.geminiClient.sendMessageStream(
llmParts,
aborted,
/*prompt_id*/ '',
this.currentPromptId,
);
} else if (anyConfirmationHandled) {
logger.info(