fix(a2a-server): Remove unsafe type assertions in agent (#19723)

This commit is contained in:
nityam
2026-02-24 04:10:55 +05:30
committed by GitHub
parent 70856d5a6e
commit dae67983a8
5 changed files with 168 additions and 55 deletions

View File

@@ -29,6 +29,8 @@ import {
CoderAgentEvent,
getPersistedState,
setPersistedState,
getContextIdFromMetadata,
getAgentSettingsFromMetadata,
} from '../types.js';
import { loadConfig, loadEnvironment, setTargetDir } from '../config/config.js';
import { loadSettings } from '../config/settings.js';
@@ -117,8 +119,7 @@ export class CoderAgentExecutor implements AgentExecutor {
const agentSettings = persistedState._agentSettings;
const config = await this.getConfig(agentSettings, sdkTask.id);
const contextId: string =
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
(metadata['_contextId'] as string) || sdkTask.contextId;
getContextIdFromMetadata(metadata) || sdkTask.contextId;
const runtimeTask = await Task.create(
sdkTask.id,
contextId,
@@ -141,8 +142,10 @@ export class CoderAgentExecutor implements AgentExecutor {
agentSettingsInput?: AgentSettings,
eventBus?: ExecutionEventBus,
): Promise<TaskWrapper> {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const agentSettings = agentSettingsInput || ({} as AgentSettings);
const agentSettings: AgentSettings = agentSettingsInput || {
kind: CoderAgentEvent.StateAgentSettingsEvent,
workspacePath: process.cwd(),
};
const config = await this.getConfig(agentSettings, taskId);
const runtimeTask = await Task.create(
taskId,
@@ -292,8 +295,7 @@ export class CoderAgentExecutor implements AgentExecutor {
const contextId: string =
userMessage.contextId ||
sdkTask?.contextId ||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
(sdkTask?.metadata?.['_contextId'] as string) ||
getContextIdFromMetadata(sdkTask?.metadata) ||
uuidv4();
logger.info(
@@ -388,10 +390,7 @@ export class CoderAgentExecutor implements AgentExecutor {
}
} else {
logger.info(`[CoderAgentExecutor] Creating new task ${taskId}.`);
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const agentSettings = userMessage.metadata?.[
'coderAgent'
] as AgentSettings;
const agentSettings = getAgentSettingsFromMetadata(userMessage.metadata);
try {
wrapper = await this.createTask(
taskId,

View File

@@ -513,7 +513,10 @@ describe('Task', () => {
{
request: { callId: '1' },
status: 'awaiting_approval',
confirmationDetails: { onConfirm: onConfirmSpy },
confirmationDetails: {
type: 'edit',
onConfirm: onConfirmSpy,
},
},
] as unknown as ToolCall[];
@@ -533,7 +536,10 @@ describe('Task', () => {
{
request: { callId: '1' },
status: 'awaiting_approval',
confirmationDetails: { onConfirm: onConfirmSpy },
confirmationDetails: {
type: 'edit',
onConfirm: onConfirmSpy,
},
},
] as unknown as ToolCall[];

View File

@@ -59,6 +59,33 @@ import type { PartUnion, Part as genAiPart } from '@google/genai';
type UnionKeys<T> = T extends T ? keyof T : never;
type ConfirmationType = ToolCallConfirmationDetails['type'];
const VALID_CONFIRMATION_TYPES: readonly ConfirmationType[] = [
'edit',
'exec',
'mcp',
'info',
'ask_user',
'exit_plan_mode',
] as const;
function isToolCallConfirmationDetails(
value: unknown,
): value is ToolCallConfirmationDetails {
if (
typeof value !== 'object' ||
value === null ||
!('onConfirm' in value) ||
typeof value.onConfirm !== 'function' ||
!('type' in value) ||
typeof value.type !== 'string'
) {
return false;
}
return (VALID_CONFIRMATION_TYPES as readonly string[]).includes(value.type);
}
export class Task {
id: string;
contextId: string;
@@ -376,11 +403,10 @@ export class Task {
}
if (tc.status === 'awaiting_approval' && tc.confirmationDetails) {
this.pendingToolConfirmationDetails.set(
tc.request.callId,
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
tc.confirmationDetails as ToolCallConfirmationDetails,
);
const details = tc.confirmationDetails;
if (isToolCallConfirmationDetails(details)) {
this.pendingToolConfirmationDetails.set(tc.request.callId, details);
}
}
// Only send an update if the status has actually changed.
@@ -412,11 +438,12 @@ export class Task {
);
toolCalls.forEach((tc: ToolCall) => {
if (tc.status === 'awaiting_approval' && tc.confirmationDetails) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises, @typescript-eslint/no-unsafe-type-assertion
(tc.confirmationDetails as ToolCallConfirmationDetails).onConfirm(
ToolConfirmationOutcome.ProceedOnce,
);
this.pendingToolConfirmationDetails.delete(tc.request.callId);
const details = tc.confirmationDetails;
if (isToolCallConfirmationDetails(details)) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
details.onConfirm(ToolConfirmationOutcome.ProceedOnce);
this.pendingToolConfirmationDetails.delete(tc.request.callId);
}
}
});
return;
@@ -466,15 +493,13 @@ export class Task {
T extends ToolCall | AnyDeclarativeTool,
K extends UnionKeys<T>,
>(from: T, ...fields: K[]): Partial<T> {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const ret = {} as Pick<T, K>;
const ret: Partial<T> = {};
for (const field of fields) {
if (field in from) {
if (field in from && from[field] !== undefined) {
ret[field] = from[field];
}
}
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
return ret as Partial<T>;
return ret;
}
private toolStatusMessage(
@@ -485,8 +510,11 @@ export class Task {
const messageParts: Part[] = [];
// Create a serializable version of the ToolCall (pick necessary
// properties/avoid methods causing circular reference errors)
const serializableToolCall: Partial<ToolCall> = this._pickFields(
// properties/avoid methods causing circular reference errors).
// Type allows tool to be Partial<AnyDeclarativeTool> for serialization.
const serializableToolCall: Partial<Omit<ToolCall, 'tool'>> & {
tool?: Partial<AnyDeclarativeTool>;
} = this._pickFields(
tc,
'request',
'status',
@@ -496,8 +524,7 @@ export class Task {
);
if (tc.tool) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
serializableToolCall.tool = this._pickFields(
const toolFields = this._pickFields(
tc.tool,
'name',
'displayName',
@@ -507,7 +534,8 @@ export class Task {
'canUpdateOutput',
'schema',
'parameterSchema',
) as AnyDeclarativeTool;
);
serializableToolCall.tool = toolFields;
}
messageParts.push({
@@ -530,8 +558,15 @@ export class Task {
old_string: string,
new_string: string,
): Promise<string> {
// Validate path to prevent path traversal vulnerabilities
const resolvedPath = path.resolve(this.config.getTargetDir(), file_path);
const pathError = this.config.validatePathAccess(resolvedPath, 'read');
if (pathError) {
throw new Error(`Path validation failed: ${pathError}`);
}
try {
const currentContent = await fs.readFile(file_path, 'utf8');
const currentContent = await fs.readFile(resolvedPath, 'utf8');
return this._applyReplacement(
currentContent,
old_string,
@@ -625,15 +660,32 @@ export class Task {
request.args['old_string'] &&
request.args['new_string']
) {
const newContent = await this.getProposedContent(
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
request.args['file_path'] as string,
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
request.args['old_string'] as string,
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
request.args['new_string'] as string,
);
return { ...request, args: { ...request.args, newContent } };
const filePath = request.args['file_path'];
const oldString = request.args['old_string'];
const newString = request.args['new_string'];
if (
typeof filePath === 'string' &&
typeof oldString === 'string' &&
typeof newString === 'string'
) {
// Resolve and validate path to prevent path traversal (user-controlled file_path).
const resolvedPath = path.resolve(
this.config.getTargetDir(),
filePath,
);
const pathError = this.config.validatePathAccess(
resolvedPath,
'read',
);
if (!pathError) {
const newContent = await this.getProposedContent(
resolvedPath,
oldString,
newString,
);
return { ...request, args: { ...request.args, newContent } };
}
}
}
return request;
}),
@@ -725,10 +777,17 @@ export class Task {
break;
case GeminiEventType.Error:
default: {
// Block scope for lexical declaration
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const errorEvent = event as ServerGeminiErrorEvent; // Type assertion
const errorMessage = errorEvent.value?.error
// Use type guard instead of unsafe type assertion
let errorEvent: ServerGeminiErrorEvent | undefined;
if (
event.type === GeminiEventType.Error &&
event.value &&
typeof event.value === 'object' &&
'error' in event.value
) {
errorEvent = event;
}
const errorMessage = errorEvent?.value?.error
? getErrorMessage(errorEvent.value.error)
: 'Unknown error from LLM stream';
logger.error(
@@ -737,7 +796,7 @@ export class Task {
);
let errMessage = `Unknown error from LLM stream: ${JSON.stringify(event)}`;
if (errorEvent.value?.error) {
if (errorEvent?.value?.error) {
errMessage = parseAndFormatApiError(errorEvent.value.error);
}
this.cancelPendingTools(`LLM stream error: ${errorMessage}`);
@@ -814,12 +873,11 @@ export class Task {
// If `edit` tool call, pass updated payload if presesent
if (confirmationDetails.type === 'edit') {
const payload = part.data['newContent']
? ({
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
newContent: part.data['newContent'] as string,
} as ToolConfirmationPayload)
: undefined;
const newContent = part.data['newContent'];
const payload =
typeof newContent === 'string'
? ({ newContent } as ToolConfirmationPayload)
: undefined;
this.skipFinalTrueAfterInlineEdit = !!payload;
try {
await confirmationDetails.onConfirm(confirmationOutcome, payload);

View File

@@ -122,11 +122,60 @@ export type PersistedTaskMetadata = { [k: string]: unknown };
export const METADATA_KEY = '__persistedState';
function isAgentSettings(value: unknown): value is AgentSettings {
return (
typeof value === 'object' &&
value !== null &&
'kind' in value &&
value.kind === CoderAgentEvent.StateAgentSettingsEvent &&
'workspacePath' in value &&
typeof value.workspacePath === 'string'
);
}
function isPersistedStateMetadata(
value: unknown,
): value is PersistedStateMetadata {
return (
typeof value === 'object' &&
value !== null &&
'_agentSettings' in value &&
'_taskState' in value &&
isAgentSettings(value._agentSettings)
);
}
export function getPersistedState(
metadata: PersistedTaskMetadata,
): PersistedStateMetadata | undefined {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
return metadata?.[METADATA_KEY] as PersistedStateMetadata | undefined;
const state = metadata?.[METADATA_KEY];
if (isPersistedStateMetadata(state)) {
return state;
}
return undefined;
}
export function getContextIdFromMetadata(
metadata: PersistedTaskMetadata | undefined,
): string | undefined {
if (!metadata) {
return undefined;
}
const contextId = metadata['_contextId'];
return typeof contextId === 'string' ? contextId : undefined;
}
export function getAgentSettingsFromMetadata(
metadata: PersistedTaskMetadata | undefined,
): AgentSettings | undefined {
if (!metadata) {
return undefined;
}
const coderAgent = metadata['coderAgent'];
if (isAgentSettings(coderAgent)) {
return coderAgent;
}
return undefined;
}
export function setPersistedState(

View File

@@ -71,6 +71,7 @@ export function createMockConfig(
getMcpServers: vi.fn().mockReturnValue({}),
}),
getGitService: vi.fn(),
validatePathAccess: vi.fn().mockReturnValue(undefined),
...overrides,
} as unknown as Config;
mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus());