mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 14:10:37 -07:00
Add prompt_id propagation in a2a-server task (#14581)
This commit is contained in:
@@ -18,9 +18,10 @@ import {
|
||||
GeminiEventType,
|
||||
type Config,
|
||||
type ToolCallRequestInfo,
|
||||
type CompletedToolCall,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { createMockConfig } from '../utils/testing_utils.js';
|
||||
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
|
||||
import type { ExecutionEventBus, RequestContext } from '@a2a-js/sdk/server';
|
||||
import { CoderAgentEvent } from '../types.js';
|
||||
import type { ToolCall } from '@google/gemini-cli-core';
|
||||
|
||||
@@ -318,4 +319,87 @@ describe('Task', () => {
|
||||
expect(finalCall).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('currentPromptId and promptCount', () => {
|
||||
it('should correctly initialize and update promptId and promptCount', async () => {
|
||||
const mockConfig = createMockConfig();
|
||||
mockConfig.getGeminiClient = vi.fn().mockReturnValue({
|
||||
sendMessageStream: vi.fn().mockReturnValue((async function* () {})()),
|
||||
});
|
||||
mockConfig.getSessionId = () => 'test-session-id';
|
||||
|
||||
const mockEventBus: ExecutionEventBus = {
|
||||
publish: vi.fn(),
|
||||
on: vi.fn(),
|
||||
off: vi.fn(),
|
||||
once: vi.fn(),
|
||||
removeAllListeners: vi.fn(),
|
||||
finished: vi.fn(),
|
||||
};
|
||||
|
||||
// @ts-expect-error - Calling private constructor
|
||||
const task = new Task(
|
||||
'task-id',
|
||||
'context-id',
|
||||
mockConfig as Config,
|
||||
mockEventBus,
|
||||
);
|
||||
|
||||
// Initial state
|
||||
expect(task.currentPromptId).toBeUndefined();
|
||||
expect(task.promptCount).toBe(0);
|
||||
|
||||
// First user message should set prompt_id
|
||||
const userMessage1 = {
|
||||
userMessage: {
|
||||
parts: [{ kind: 'text', text: 'hello' }],
|
||||
},
|
||||
} as RequestContext;
|
||||
const abortController1 = new AbortController();
|
||||
for await (const _ of task.acceptUserMessage(
|
||||
userMessage1,
|
||||
abortController1.signal,
|
||||
)) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
const expectedPromptId1 = 'test-session-id########0';
|
||||
expect(task.promptCount).toBe(1);
|
||||
expect(task.currentPromptId).toBe(expectedPromptId1);
|
||||
|
||||
// A new user message should generate a new prompt_id
|
||||
const userMessage2 = {
|
||||
userMessage: {
|
||||
parts: [{ kind: 'text', text: 'world' }],
|
||||
},
|
||||
} as RequestContext;
|
||||
const abortController2 = new AbortController();
|
||||
for await (const _ of task.acceptUserMessage(
|
||||
userMessage2,
|
||||
abortController2.signal,
|
||||
)) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
const expectedPromptId2 = 'test-session-id########1';
|
||||
expect(task.promptCount).toBe(2);
|
||||
expect(task.currentPromptId).toBe(expectedPromptId2);
|
||||
|
||||
// Subsequent tool call processing should use the same prompt_id
|
||||
const completedTool = {
|
||||
request: { callId: 'tool-1' },
|
||||
response: { responseParts: [{ text: 'tool output' }] },
|
||||
} as CompletedToolCall;
|
||||
const abortController3 = new AbortController();
|
||||
for await (const _ of task.sendCompletedToolsToLlm(
|
||||
[completedTool],
|
||||
abortController3.signal,
|
||||
)) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
expect(task.promptCount).toBe(2);
|
||||
expect(task.currentPromptId).toBe(expectedPromptId2);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -68,6 +68,8 @@ export class Task {
|
||||
completedToolCalls: CompletedToolCall[];
|
||||
skipFinalTrueAfterInlineEdit = false;
|
||||
modelInfo?: string;
|
||||
currentPromptId: string | undefined;
|
||||
promptCount = 0;
|
||||
|
||||
// For tool waiting logic
|
||||
private pendingToolCalls: Map<string, string> = new Map(); //toolCallId --> status
|
||||
@@ -859,11 +861,10 @@ export class Task {
|
||||
};
|
||||
// Set task state to working as we are about to call LLM
|
||||
this.setTaskStateAndPublishUpdate('working', stateChange);
|
||||
// TODO: Determine what it mean to have, then add a prompt ID.
|
||||
yield* this.geminiClient.sendMessageStream(
|
||||
llmParts,
|
||||
aborted,
|
||||
/*prompt_id*/ '',
|
||||
completedToolCalls[0]?.request.prompt_id ?? '',
|
||||
);
|
||||
}
|
||||
|
||||
@@ -893,17 +894,18 @@ export class Task {
|
||||
}
|
||||
|
||||
if (hasContentForLlm) {
|
||||
this.currentPromptId =
|
||||
this.config.getSessionId() + '########' + this.promptCount++;
|
||||
logger.info('[Task] Sending new parts to LLM.');
|
||||
const stateChange: StateChange = {
|
||||
kind: CoderAgentEvent.StateChangeEvent,
|
||||
};
|
||||
// Set task state to working as we are about to call LLM
|
||||
this.setTaskStateAndPublishUpdate('working', stateChange);
|
||||
// TODO: Determine what it mean to have, then add a prompt ID.
|
||||
yield* this.geminiClient.sendMessageStream(
|
||||
llmParts,
|
||||
aborted,
|
||||
/*prompt_id*/ '',
|
||||
this.currentPromptId,
|
||||
);
|
||||
} else if (anyConfirmationHandled) {
|
||||
logger.info(
|
||||
|
||||
Reference in New Issue
Block a user