diff --git a/packages/a2a-server/src/agent/task.test.ts b/packages/a2a-server/src/agent/task.test.ts index 8b347f70e2..d514079592 100644 --- a/packages/a2a-server/src/agent/task.test.ts +++ b/packages/a2a-server/src/agent/task.test.ts @@ -151,6 +151,54 @@ describe('Task', () => { }, ]); }); + + it('should update modelInfo and reflect it in metadata and status updates', async () => { + const mockConfig = createMockConfig(); + const mockEventBus: ExecutionEventBus = { + publish: vi.fn(), + on: vi.fn(), + off: vi.fn(), + once: vi.fn(), + removeAllListeners: vi.fn(), + finished: vi.fn(), + }; + + // @ts-expect-error - Calling private constructor for test purposes. + const task = new Task( + 'task-id', + 'context-id', + mockConfig as Config, + mockEventBus, + ); + + const modelInfoEvent = { + type: GeminiEventType.ModelInfo, + value: 'new-model-name', + }; + + await task.acceptAgentMessage(modelInfoEvent); + + expect(task.modelInfo).toBe('new-model-name'); + + // Check getMetadata + const metadata = await task.getMetadata(); + expect(metadata.model).toBe('new-model-name'); + + // Check status update + task.setTaskStateAndPublishUpdate( + 'working', + { kind: CoderAgentEvent.StateChangeEvent }, + 'Working...', + ); + + expect(mockEventBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + metadata: expect.objectContaining({ + model: 'new-model-name', + }), + }), + ); + }); }); describe('_schedulerToolCallsUpdate', () => { diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index 12f58be8b9..1fdc1d5e60 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -66,6 +66,7 @@ export class Task { eventBus?: ExecutionEventBus; completedToolCalls: CompletedToolCall[]; skipFinalTrueAfterInlineEdit = false; + modelInfo?: string; // For tool waiting logic private pendingToolCalls: Map = new Map(); //toolCallId --> status @@ -135,7 +136,7 @@ export class Task { id: this.id, contextId: this.contextId, taskState: this.taskState, - model: this.config.getModel(), + model: this.modelInfo || this.config.getModel(), mcpServers: servers, availableTools, }; @@ -230,7 +231,7 @@ export class Task { traceId?: string; } = { coderAgent: coderAgentMessage, - model: this.config.getModel(), + model: this.modelInfo || this.config.getModel(), userTier: this.config.getUserTier(), }; @@ -647,6 +648,9 @@ export class Task { case GeminiEventType.Finished: logger.info(`[Task ${this.id}] Agent finished its turn.`); break; + case GeminiEventType.ModelInfo: + this.modelInfo = event.value; + break; case GeminiEventType.Error: default: { // Block scope for lexical declaration diff --git a/packages/a2a-server/src/utils/testing_utils.ts b/packages/a2a-server/src/utils/testing_utils.ts index f0c21b467b..fcd184924c 100644 --- a/packages/a2a-server/src/utils/testing_utils.ts +++ b/packages/a2a-server/src/utils/testing_utils.ts @@ -27,6 +27,8 @@ export function createMockConfig( getToolRegistry: vi.fn().mockReturnValue({ getTool: vi.fn(), getAllToolNames: vi.fn().mockReturnValue([]), + getAllTools: vi.fn().mockReturnValue([]), + getToolsByServer: vi.fn().mockReturnValue([]), }), getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT), getIdeMode: vi.fn().mockReturnValue(false), @@ -57,6 +59,9 @@ export function createMockConfig( getPolicyEngine: vi.fn(), getEnableExtensionReloading: vi.fn().mockReturnValue(false), getEnableHooks: vi.fn().mockReturnValue(false), + getMcpClientManager: vi.fn().mockReturnValue({ + getMcpServers: vi.fn().mockReturnValue({}), + }), ...overrides, } as unknown as Config; mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus());