fix(a2a-server): Remove unsafe type assertions in agent (#19723)

This commit is contained in:
nityam
2026-02-24 04:10:55 +05:30
committed by GitHub
parent 70856d5a6e
commit dae67983a8
5 changed files with 168 additions and 55 deletions
+9 -10
View File
@@ -29,6 +29,8 @@ import {
CoderAgentEvent, CoderAgentEvent,
getPersistedState, getPersistedState,
setPersistedState, setPersistedState,
getContextIdFromMetadata,
getAgentSettingsFromMetadata,
} from '../types.js'; } from '../types.js';
import { loadConfig, loadEnvironment, setTargetDir } from '../config/config.js'; import { loadConfig, loadEnvironment, setTargetDir } from '../config/config.js';
import { loadSettings } from '../config/settings.js'; import { loadSettings } from '../config/settings.js';
@@ -117,8 +119,7 @@ export class CoderAgentExecutor implements AgentExecutor {
const agentSettings = persistedState._agentSettings; const agentSettings = persistedState._agentSettings;
const config = await this.getConfig(agentSettings, sdkTask.id); const config = await this.getConfig(agentSettings, sdkTask.id);
const contextId: string = const contextId: string =
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion getContextIdFromMetadata(metadata) || sdkTask.contextId;
(metadata['_contextId'] as string) || sdkTask.contextId;
const runtimeTask = await Task.create( const runtimeTask = await Task.create(
sdkTask.id, sdkTask.id,
contextId, contextId,
@@ -141,8 +142,10 @@ export class CoderAgentExecutor implements AgentExecutor {
agentSettingsInput?: AgentSettings, agentSettingsInput?: AgentSettings,
eventBus?: ExecutionEventBus, eventBus?: ExecutionEventBus,
): Promise<TaskWrapper> { ): Promise<TaskWrapper> {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const agentSettings: AgentSettings = agentSettingsInput || {
const agentSettings = agentSettingsInput || ({} as AgentSettings); kind: CoderAgentEvent.StateAgentSettingsEvent,
workspacePath: process.cwd(),
};
const config = await this.getConfig(agentSettings, taskId); const config = await this.getConfig(agentSettings, taskId);
const runtimeTask = await Task.create( const runtimeTask = await Task.create(
taskId, taskId,
@@ -292,8 +295,7 @@ export class CoderAgentExecutor implements AgentExecutor {
const contextId: string = const contextId: string =
userMessage.contextId || userMessage.contextId ||
sdkTask?.contextId || sdkTask?.contextId ||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion getContextIdFromMetadata(sdkTask?.metadata) ||
(sdkTask?.metadata?.['_contextId'] as string) ||
uuidv4(); uuidv4();
logger.info( logger.info(
@@ -388,10 +390,7 @@ export class CoderAgentExecutor implements AgentExecutor {
} }
} else { } else {
logger.info(`[CoderAgentExecutor] Creating new task ${taskId}.`); logger.info(`[CoderAgentExecutor] Creating new task ${taskId}.`);
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const agentSettings = getAgentSettingsFromMetadata(userMessage.metadata);
const agentSettings = userMessage.metadata?.[
'coderAgent'
] as AgentSettings;
try { try {
wrapper = await this.createTask( wrapper = await this.createTask(
taskId, taskId,
+8 -2
View File
@@ -513,7 +513,10 @@ describe('Task', () => {
{ {
request: { callId: '1' }, request: { callId: '1' },
status: 'awaiting_approval', status: 'awaiting_approval',
confirmationDetails: { onConfirm: onConfirmSpy }, confirmationDetails: {
type: 'edit',
onConfirm: onConfirmSpy,
},
}, },
] as unknown as ToolCall[]; ] as unknown as ToolCall[];
@@ -533,7 +536,10 @@ describe('Task', () => {
{ {
request: { callId: '1' }, request: { callId: '1' },
status: 'awaiting_approval', status: 'awaiting_approval',
confirmationDetails: { onConfirm: onConfirmSpy }, confirmationDetails: {
type: 'edit',
onConfirm: onConfirmSpy,
},
}, },
] as unknown as ToolCall[]; ] as unknown as ToolCall[];
+99 -41
View File
@@ -59,6 +59,33 @@ import type { PartUnion, Part as genAiPart } from '@google/genai';
type UnionKeys<T> = T extends T ? keyof T : never; type UnionKeys<T> = T extends T ? keyof T : never;
type ConfirmationType = ToolCallConfirmationDetails['type'];
const VALID_CONFIRMATION_TYPES: readonly ConfirmationType[] = [
'edit',
'exec',
'mcp',
'info',
'ask_user',
'exit_plan_mode',
] as const;
function isToolCallConfirmationDetails(
value: unknown,
): value is ToolCallConfirmationDetails {
if (
typeof value !== 'object' ||
value === null ||
!('onConfirm' in value) ||
typeof value.onConfirm !== 'function' ||
!('type' in value) ||
typeof value.type !== 'string'
) {
return false;
}
return (VALID_CONFIRMATION_TYPES as readonly string[]).includes(value.type);
}
export class Task { export class Task {
id: string; id: string;
contextId: string; contextId: string;
@@ -376,11 +403,10 @@ export class Task {
} }
if (tc.status === 'awaiting_approval' && tc.confirmationDetails) { if (tc.status === 'awaiting_approval' && tc.confirmationDetails) {
this.pendingToolConfirmationDetails.set( const details = tc.confirmationDetails;
tc.request.callId, if (isToolCallConfirmationDetails(details)) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion this.pendingToolConfirmationDetails.set(tc.request.callId, details);
tc.confirmationDetails as ToolCallConfirmationDetails, }
);
} }
// Only send an update if the status has actually changed. // Only send an update if the status has actually changed.
@@ -412,11 +438,12 @@ export class Task {
); );
toolCalls.forEach((tc: ToolCall) => { toolCalls.forEach((tc: ToolCall) => {
if (tc.status === 'awaiting_approval' && tc.confirmationDetails) { if (tc.status === 'awaiting_approval' && tc.confirmationDetails) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises, @typescript-eslint/no-unsafe-type-assertion const details = tc.confirmationDetails;
(tc.confirmationDetails as ToolCallConfirmationDetails).onConfirm( if (isToolCallConfirmationDetails(details)) {
ToolConfirmationOutcome.ProceedOnce, // eslint-disable-next-line @typescript-eslint/no-floating-promises
); details.onConfirm(ToolConfirmationOutcome.ProceedOnce);
this.pendingToolConfirmationDetails.delete(tc.request.callId); this.pendingToolConfirmationDetails.delete(tc.request.callId);
}
} }
}); });
return; return;
@@ -466,15 +493,13 @@ export class Task {
T extends ToolCall | AnyDeclarativeTool, T extends ToolCall | AnyDeclarativeTool,
K extends UnionKeys<T>, K extends UnionKeys<T>,
>(from: T, ...fields: K[]): Partial<T> { >(from: T, ...fields: K[]): Partial<T> {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const ret: Partial<T> = {};
const ret = {} as Pick<T, K>;
for (const field of fields) { for (const field of fields) {
if (field in from) { if (field in from && from[field] !== undefined) {
ret[field] = from[field]; ret[field] = from[field];
} }
} }
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion return ret;
return ret as Partial<T>;
} }
private toolStatusMessage( private toolStatusMessage(
@@ -485,8 +510,11 @@ export class Task {
const messageParts: Part[] = []; const messageParts: Part[] = [];
// Create a serializable version of the ToolCall (pick necessary // Create a serializable version of the ToolCall (pick necessary
// properties/avoid methods causing circular reference errors) // properties/avoid methods causing circular reference errors).
const serializableToolCall: Partial<ToolCall> = this._pickFields( // Type allows tool to be Partial<AnyDeclarativeTool> for serialization.
const serializableToolCall: Partial<Omit<ToolCall, 'tool'>> & {
tool?: Partial<AnyDeclarativeTool>;
} = this._pickFields(
tc, tc,
'request', 'request',
'status', 'status',
@@ -496,8 +524,7 @@ export class Task {
); );
if (tc.tool) { if (tc.tool) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const toolFields = this._pickFields(
serializableToolCall.tool = this._pickFields(
tc.tool, tc.tool,
'name', 'name',
'displayName', 'displayName',
@@ -507,7 +534,8 @@ export class Task {
'canUpdateOutput', 'canUpdateOutput',
'schema', 'schema',
'parameterSchema', 'parameterSchema',
) as AnyDeclarativeTool; );
serializableToolCall.tool = toolFields;
} }
messageParts.push({ messageParts.push({
@@ -530,8 +558,15 @@ export class Task {
old_string: string, old_string: string,
new_string: string, new_string: string,
): Promise<string> { ): Promise<string> {
// Validate path to prevent path traversal vulnerabilities
const resolvedPath = path.resolve(this.config.getTargetDir(), file_path);
const pathError = this.config.validatePathAccess(resolvedPath, 'read');
if (pathError) {
throw new Error(`Path validation failed: ${pathError}`);
}
try { try {
const currentContent = await fs.readFile(file_path, 'utf8'); const currentContent = await fs.readFile(resolvedPath, 'utf8');
return this._applyReplacement( return this._applyReplacement(
currentContent, currentContent,
old_string, old_string,
@@ -625,15 +660,32 @@ export class Task {
request.args['old_string'] && request.args['old_string'] &&
request.args['new_string'] request.args['new_string']
) { ) {
const newContent = await this.getProposedContent( const filePath = request.args['file_path'];
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const oldString = request.args['old_string'];
request.args['file_path'] as string, const newString = request.args['new_string'];
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion if (
request.args['old_string'] as string, typeof filePath === 'string' &&
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion typeof oldString === 'string' &&
request.args['new_string'] as string, typeof newString === 'string'
); ) {
return { ...request, args: { ...request.args, newContent } }; // Resolve and validate path to prevent path traversal (user-controlled file_path).
const resolvedPath = path.resolve(
this.config.getTargetDir(),
filePath,
);
const pathError = this.config.validatePathAccess(
resolvedPath,
'read',
);
if (!pathError) {
const newContent = await this.getProposedContent(
resolvedPath,
oldString,
newString,
);
return { ...request, args: { ...request.args, newContent } };
}
}
} }
return request; return request;
}), }),
@@ -725,10 +777,17 @@ export class Task {
break; break;
case GeminiEventType.Error: case GeminiEventType.Error:
default: { default: {
// Block scope for lexical declaration // Use type guard instead of unsafe type assertion
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion let errorEvent: ServerGeminiErrorEvent | undefined;
const errorEvent = event as ServerGeminiErrorEvent; // Type assertion if (
const errorMessage = errorEvent.value?.error event.type === GeminiEventType.Error &&
event.value &&
typeof event.value === 'object' &&
'error' in event.value
) {
errorEvent = event;
}
const errorMessage = errorEvent?.value?.error
? getErrorMessage(errorEvent.value.error) ? getErrorMessage(errorEvent.value.error)
: 'Unknown error from LLM stream'; : 'Unknown error from LLM stream';
logger.error( logger.error(
@@ -737,7 +796,7 @@ export class Task {
); );
let errMessage = `Unknown error from LLM stream: ${JSON.stringify(event)}`; let errMessage = `Unknown error from LLM stream: ${JSON.stringify(event)}`;
if (errorEvent.value?.error) { if (errorEvent?.value?.error) {
errMessage = parseAndFormatApiError(errorEvent.value.error); errMessage = parseAndFormatApiError(errorEvent.value.error);
} }
this.cancelPendingTools(`LLM stream error: ${errorMessage}`); this.cancelPendingTools(`LLM stream error: ${errorMessage}`);
@@ -814,12 +873,11 @@ export class Task {
// If `edit` tool call, pass updated payload if presesent // If `edit` tool call, pass updated payload if presesent
if (confirmationDetails.type === 'edit') { if (confirmationDetails.type === 'edit') {
const payload = part.data['newContent'] const newContent = part.data['newContent'];
? ({ const payload =
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion typeof newContent === 'string'
newContent: part.data['newContent'] as string, ? ({ newContent } as ToolConfirmationPayload)
} as ToolConfirmationPayload) : undefined;
: undefined;
this.skipFinalTrueAfterInlineEdit = !!payload; this.skipFinalTrueAfterInlineEdit = !!payload;
try { try {
await confirmationDetails.onConfirm(confirmationOutcome, payload); await confirmationDetails.onConfirm(confirmationOutcome, payload);
+51 -2
View File
@@ -122,11 +122,60 @@ export type PersistedTaskMetadata = { [k: string]: unknown };
export const METADATA_KEY = '__persistedState'; export const METADATA_KEY = '__persistedState';
function isAgentSettings(value: unknown): value is AgentSettings {
return (
typeof value === 'object' &&
value !== null &&
'kind' in value &&
value.kind === CoderAgentEvent.StateAgentSettingsEvent &&
'workspacePath' in value &&
typeof value.workspacePath === 'string'
);
}
function isPersistedStateMetadata(
value: unknown,
): value is PersistedStateMetadata {
return (
typeof value === 'object' &&
value !== null &&
'_agentSettings' in value &&
'_taskState' in value &&
isAgentSettings(value._agentSettings)
);
}
export function getPersistedState( export function getPersistedState(
metadata: PersistedTaskMetadata, metadata: PersistedTaskMetadata,
): PersistedStateMetadata | undefined { ): PersistedStateMetadata | undefined {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const state = metadata?.[METADATA_KEY];
return metadata?.[METADATA_KEY] as PersistedStateMetadata | undefined; if (isPersistedStateMetadata(state)) {
return state;
}
return undefined;
}
export function getContextIdFromMetadata(
metadata: PersistedTaskMetadata | undefined,
): string | undefined {
if (!metadata) {
return undefined;
}
const contextId = metadata['_contextId'];
return typeof contextId === 'string' ? contextId : undefined;
}
export function getAgentSettingsFromMetadata(
metadata: PersistedTaskMetadata | undefined,
): AgentSettings | undefined {
if (!metadata) {
return undefined;
}
const coderAgent = metadata['coderAgent'];
if (isAgentSettings(coderAgent)) {
return coderAgent;
}
return undefined;
} }
export function setPersistedState( export function setPersistedState(
@@ -71,6 +71,7 @@ export function createMockConfig(
getMcpServers: vi.fn().mockReturnValue({}), getMcpServers: vi.fn().mockReturnValue({}),
}), }),
getGitService: vi.fn(), getGitService: vi.fn(),
validatePathAccess: vi.fn().mockReturnValue(undefined),
...overrides, ...overrides,
} as unknown as Config; } as unknown as Config;
mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus()); mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus());