From dae67983a8182fb9ea2650ee09bc96a64d024e4b Mon Sep 17 00:00:00 2001 From: nityam Date: Tue, 24 Feb 2026 04:10:55 +0530 Subject: [PATCH] fix(a2a-server): Remove unsafe type assertions in agent (#19723) --- packages/a2a-server/src/agent/executor.ts | 19 ++- packages/a2a-server/src/agent/task.test.ts | 10 +- packages/a2a-server/src/agent/task.ts | 140 +++++++++++++----- packages/a2a-server/src/types.ts | 53 ++++++- .../a2a-server/src/utils/testing_utils.ts | 1 + 5 files changed, 168 insertions(+), 55 deletions(-) diff --git a/packages/a2a-server/src/agent/executor.ts b/packages/a2a-server/src/agent/executor.ts index b0522a945f..e2287a2562 100644 --- a/packages/a2a-server/src/agent/executor.ts +++ b/packages/a2a-server/src/agent/executor.ts @@ -29,6 +29,8 @@ import { CoderAgentEvent, getPersistedState, setPersistedState, + getContextIdFromMetadata, + getAgentSettingsFromMetadata, } from '../types.js'; import { loadConfig, loadEnvironment, setTargetDir } from '../config/config.js'; import { loadSettings } from '../config/settings.js'; @@ -117,8 +119,7 @@ export class CoderAgentExecutor implements AgentExecutor { const agentSettings = persistedState._agentSettings; const config = await this.getConfig(agentSettings, sdkTask.id); const contextId: string = - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - (metadata['_contextId'] as string) || sdkTask.contextId; + getContextIdFromMetadata(metadata) || sdkTask.contextId; const runtimeTask = await Task.create( sdkTask.id, contextId, @@ -141,8 +142,10 @@ export class CoderAgentExecutor implements AgentExecutor { agentSettingsInput?: AgentSettings, eventBus?: ExecutionEventBus, ): Promise { - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const agentSettings = agentSettingsInput || ({} as AgentSettings); + const agentSettings: AgentSettings = agentSettingsInput || { + kind: CoderAgentEvent.StateAgentSettingsEvent, + workspacePath: process.cwd(), + }; const config = await this.getConfig(agentSettings, taskId); const runtimeTask = await Task.create( taskId, @@ -292,8 +295,7 @@ export class CoderAgentExecutor implements AgentExecutor { const contextId: string = userMessage.contextId || sdkTask?.contextId || - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - (sdkTask?.metadata?.['_contextId'] as string) || + getContextIdFromMetadata(sdkTask?.metadata) || uuidv4(); logger.info( @@ -388,10 +390,7 @@ export class CoderAgentExecutor implements AgentExecutor { } } else { logger.info(`[CoderAgentExecutor] Creating new task ${taskId}.`); - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const agentSettings = userMessage.metadata?.[ - 'coderAgent' - ] as AgentSettings; + const agentSettings = getAgentSettingsFromMetadata(userMessage.metadata); try { wrapper = await this.createTask( taskId, diff --git a/packages/a2a-server/src/agent/task.test.ts b/packages/a2a-server/src/agent/task.test.ts index 39cfe5eb74..81987a780b 100644 --- a/packages/a2a-server/src/agent/task.test.ts +++ b/packages/a2a-server/src/agent/task.test.ts @@ -513,7 +513,10 @@ describe('Task', () => { { request: { callId: '1' }, status: 'awaiting_approval', - confirmationDetails: { onConfirm: onConfirmSpy }, + confirmationDetails: { + type: 'edit', + onConfirm: onConfirmSpy, + }, }, ] as unknown as ToolCall[]; @@ -533,7 +536,10 @@ describe('Task', () => { { request: { callId: '1' }, status: 'awaiting_approval', - confirmationDetails: { onConfirm: onConfirmSpy }, + confirmationDetails: { + type: 'edit', + onConfirm: onConfirmSpy, + }, }, ] as unknown as ToolCall[]; diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index bc8cd121a9..c91ef72781 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -59,6 +59,33 @@ import type { PartUnion, Part as genAiPart } from '@google/genai'; type UnionKeys = T extends T ? keyof T : never; +type ConfirmationType = ToolCallConfirmationDetails['type']; + +const VALID_CONFIRMATION_TYPES: readonly ConfirmationType[] = [ + 'edit', + 'exec', + 'mcp', + 'info', + 'ask_user', + 'exit_plan_mode', +] as const; + +function isToolCallConfirmationDetails( + value: unknown, +): value is ToolCallConfirmationDetails { + if ( + typeof value !== 'object' || + value === null || + !('onConfirm' in value) || + typeof value.onConfirm !== 'function' || + !('type' in value) || + typeof value.type !== 'string' + ) { + return false; + } + return (VALID_CONFIRMATION_TYPES as readonly string[]).includes(value.type); +} + export class Task { id: string; contextId: string; @@ -376,11 +403,10 @@ export class Task { } if (tc.status === 'awaiting_approval' && tc.confirmationDetails) { - this.pendingToolConfirmationDetails.set( - tc.request.callId, - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - tc.confirmationDetails as ToolCallConfirmationDetails, - ); + const details = tc.confirmationDetails; + if (isToolCallConfirmationDetails(details)) { + this.pendingToolConfirmationDetails.set(tc.request.callId, details); + } } // Only send an update if the status has actually changed. @@ -412,11 +438,12 @@ export class Task { ); toolCalls.forEach((tc: ToolCall) => { if (tc.status === 'awaiting_approval' && tc.confirmationDetails) { - // eslint-disable-next-line @typescript-eslint/no-floating-promises, @typescript-eslint/no-unsafe-type-assertion - (tc.confirmationDetails as ToolCallConfirmationDetails).onConfirm( - ToolConfirmationOutcome.ProceedOnce, - ); - this.pendingToolConfirmationDetails.delete(tc.request.callId); + const details = tc.confirmationDetails; + if (isToolCallConfirmationDetails(details)) { + // eslint-disable-next-line @typescript-eslint/no-floating-promises + details.onConfirm(ToolConfirmationOutcome.ProceedOnce); + this.pendingToolConfirmationDetails.delete(tc.request.callId); + } } }); return; @@ -466,15 +493,13 @@ export class Task { T extends ToolCall | AnyDeclarativeTool, K extends UnionKeys, >(from: T, ...fields: K[]): Partial { - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const ret = {} as Pick; + const ret: Partial = {}; for (const field of fields) { - if (field in from) { + if (field in from && from[field] !== undefined) { ret[field] = from[field]; } } - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - return ret as Partial; + return ret; } private toolStatusMessage( @@ -485,8 +510,11 @@ export class Task { const messageParts: Part[] = []; // Create a serializable version of the ToolCall (pick necessary - // properties/avoid methods causing circular reference errors) - const serializableToolCall: Partial = this._pickFields( + // properties/avoid methods causing circular reference errors). + // Type allows tool to be Partial for serialization. + const serializableToolCall: Partial> & { + tool?: Partial; + } = this._pickFields( tc, 'request', 'status', @@ -496,8 +524,7 @@ export class Task { ); if (tc.tool) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - serializableToolCall.tool = this._pickFields( + const toolFields = this._pickFields( tc.tool, 'name', 'displayName', @@ -507,7 +534,8 @@ export class Task { 'canUpdateOutput', 'schema', 'parameterSchema', - ) as AnyDeclarativeTool; + ); + serializableToolCall.tool = toolFields; } messageParts.push({ @@ -530,8 +558,15 @@ export class Task { old_string: string, new_string: string, ): Promise { + // Validate path to prevent path traversal vulnerabilities + const resolvedPath = path.resolve(this.config.getTargetDir(), file_path); + const pathError = this.config.validatePathAccess(resolvedPath, 'read'); + if (pathError) { + throw new Error(`Path validation failed: ${pathError}`); + } + try { - const currentContent = await fs.readFile(file_path, 'utf8'); + const currentContent = await fs.readFile(resolvedPath, 'utf8'); return this._applyReplacement( currentContent, old_string, @@ -625,15 +660,32 @@ export class Task { request.args['old_string'] && request.args['new_string'] ) { - const newContent = await this.getProposedContent( - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - request.args['file_path'] as string, - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - request.args['old_string'] as string, - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - request.args['new_string'] as string, - ); - return { ...request, args: { ...request.args, newContent } }; + const filePath = request.args['file_path']; + const oldString = request.args['old_string']; + const newString = request.args['new_string']; + if ( + typeof filePath === 'string' && + typeof oldString === 'string' && + typeof newString === 'string' + ) { + // Resolve and validate path to prevent path traversal (user-controlled file_path). + const resolvedPath = path.resolve( + this.config.getTargetDir(), + filePath, + ); + const pathError = this.config.validatePathAccess( + resolvedPath, + 'read', + ); + if (!pathError) { + const newContent = await this.getProposedContent( + resolvedPath, + oldString, + newString, + ); + return { ...request, args: { ...request.args, newContent } }; + } + } } return request; }), @@ -725,10 +777,17 @@ export class Task { break; case GeminiEventType.Error: default: { - // Block scope for lexical declaration - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const errorEvent = event as ServerGeminiErrorEvent; // Type assertion - const errorMessage = errorEvent.value?.error + // Use type guard instead of unsafe type assertion + let errorEvent: ServerGeminiErrorEvent | undefined; + if ( + event.type === GeminiEventType.Error && + event.value && + typeof event.value === 'object' && + 'error' in event.value + ) { + errorEvent = event; + } + const errorMessage = errorEvent?.value?.error ? getErrorMessage(errorEvent.value.error) : 'Unknown error from LLM stream'; logger.error( @@ -737,7 +796,7 @@ export class Task { ); let errMessage = `Unknown error from LLM stream: ${JSON.stringify(event)}`; - if (errorEvent.value?.error) { + if (errorEvent?.value?.error) { errMessage = parseAndFormatApiError(errorEvent.value.error); } this.cancelPendingTools(`LLM stream error: ${errorMessage}`); @@ -814,12 +873,11 @@ export class Task { // If `edit` tool call, pass updated payload if presesent if (confirmationDetails.type === 'edit') { - const payload = part.data['newContent'] - ? ({ - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - newContent: part.data['newContent'] as string, - } as ToolConfirmationPayload) - : undefined; + const newContent = part.data['newContent']; + const payload = + typeof newContent === 'string' + ? ({ newContent } as ToolConfirmationPayload) + : undefined; this.skipFinalTrueAfterInlineEdit = !!payload; try { await confirmationDetails.onConfirm(confirmationOutcome, payload); diff --git a/packages/a2a-server/src/types.ts b/packages/a2a-server/src/types.ts index 0ed6a67994..bce233c9dd 100644 --- a/packages/a2a-server/src/types.ts +++ b/packages/a2a-server/src/types.ts @@ -122,11 +122,60 @@ export type PersistedTaskMetadata = { [k: string]: unknown }; export const METADATA_KEY = '__persistedState'; +function isAgentSettings(value: unknown): value is AgentSettings { + return ( + typeof value === 'object' && + value !== null && + 'kind' in value && + value.kind === CoderAgentEvent.StateAgentSettingsEvent && + 'workspacePath' in value && + typeof value.workspacePath === 'string' + ); +} + +function isPersistedStateMetadata( + value: unknown, +): value is PersistedStateMetadata { + return ( + typeof value === 'object' && + value !== null && + '_agentSettings' in value && + '_taskState' in value && + isAgentSettings(value._agentSettings) + ); +} + export function getPersistedState( metadata: PersistedTaskMetadata, ): PersistedStateMetadata | undefined { - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - return metadata?.[METADATA_KEY] as PersistedStateMetadata | undefined; + const state = metadata?.[METADATA_KEY]; + if (isPersistedStateMetadata(state)) { + return state; + } + return undefined; +} + +export function getContextIdFromMetadata( + metadata: PersistedTaskMetadata | undefined, +): string | undefined { + if (!metadata) { + return undefined; + } + const contextId = metadata['_contextId']; + return typeof contextId === 'string' ? contextId : undefined; +} + +export function getAgentSettingsFromMetadata( + metadata: PersistedTaskMetadata | undefined, +): AgentSettings | undefined { + if (!metadata) { + return undefined; + } + const coderAgent = metadata['coderAgent']; + if (isAgentSettings(coderAgent)) { + return coderAgent; + } + return undefined; } export function setPersistedState( diff --git a/packages/a2a-server/src/utils/testing_utils.ts b/packages/a2a-server/src/utils/testing_utils.ts index 86d0d4a4bd..9cb0657c7a 100644 --- a/packages/a2a-server/src/utils/testing_utils.ts +++ b/packages/a2a-server/src/utils/testing_utils.ts @@ -71,6 +71,7 @@ export function createMockConfig( getMcpServers: vi.fn().mockReturnValue({}), }), getGitService: vi.fn(), + validatePathAccess: vi.fn().mockReturnValue(undefined), ...overrides, } as unknown as Config; mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus());