feat(a2a): Urgent fix - Process modelInfo agent message (#14315)

This commit is contained in:
Coco Sheng
2025-12-01 15:09:02 -05:00
committed by GitHub
parent 26f050ff10
commit 806cd112ac
3 changed files with 59 additions and 2 deletions
@@ -151,6 +151,54 @@ describe('Task', () => {
},
]);
});
it('should update modelInfo and reflect it in metadata and status updates', async () => {
const mockConfig = createMockConfig();
const mockEventBus: ExecutionEventBus = {
publish: vi.fn(),
on: vi.fn(),
off: vi.fn(),
once: vi.fn(),
removeAllListeners: vi.fn(),
finished: vi.fn(),
};
// @ts-expect-error - Calling private constructor for test purposes.
const task = new Task(
'task-id',
'context-id',
mockConfig as Config,
mockEventBus,
);
const modelInfoEvent = {
type: GeminiEventType.ModelInfo,
value: 'new-model-name',
};
await task.acceptAgentMessage(modelInfoEvent);
expect(task.modelInfo).toBe('new-model-name');
// Check getMetadata
const metadata = await task.getMetadata();
expect(metadata.model).toBe('new-model-name');
// Check status update
task.setTaskStateAndPublishUpdate(
'working',
{ kind: CoderAgentEvent.StateChangeEvent },
'Working...',
);
expect(mockEventBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
metadata: expect.objectContaining({
model: 'new-model-name',
}),
}),
);
});
});
describe('_schedulerToolCallsUpdate', () => {
+6 -2
View File
@@ -66,6 +66,7 @@ export class Task {
eventBus?: ExecutionEventBus;
completedToolCalls: CompletedToolCall[];
skipFinalTrueAfterInlineEdit = false;
modelInfo?: string;
// For tool waiting logic
private pendingToolCalls: Map<string, string> = new Map(); //toolCallId --> status
@@ -135,7 +136,7 @@ export class Task {
id: this.id,
contextId: this.contextId,
taskState: this.taskState,
model: this.config.getModel(),
model: this.modelInfo || this.config.getModel(),
mcpServers: servers,
availableTools,
};
@@ -230,7 +231,7 @@ export class Task {
traceId?: string;
} = {
coderAgent: coderAgentMessage,
model: this.config.getModel(),
model: this.modelInfo || this.config.getModel(),
userTier: this.config.getUserTier(),
};
@@ -647,6 +648,9 @@ export class Task {
case GeminiEventType.Finished:
logger.info(`[Task ${this.id}] Agent finished its turn.`);
break;
case GeminiEventType.ModelInfo:
this.modelInfo = event.value;
break;
case GeminiEventType.Error:
default: {
// Block scope for lexical declaration
@@ -27,6 +27,8 @@ export function createMockConfig(
getToolRegistry: vi.fn().mockReturnValue({
getTool: vi.fn(),
getAllToolNames: vi.fn().mockReturnValue([]),
getAllTools: vi.fn().mockReturnValue([]),
getToolsByServer: vi.fn().mockReturnValue([]),
}),
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
getIdeMode: vi.fn().mockReturnValue(false),
@@ -57,6 +59,9 @@ export function createMockConfig(
getPolicyEngine: vi.fn(),
getEnableExtensionReloading: vi.fn().mockReturnValue(false),
getEnableHooks: vi.fn().mockReturnValue(false),
getMcpClientManager: vi.fn().mockReturnValue({
getMcpServers: vi.fn().mockReturnValue({}),
}),
...overrides,
} as unknown as Config;
mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus());