fix NaN loss not caught by fast-fail check

This commit is contained in:
Andrej
2026-03-10 22:31:43 -07:00
committed by GitHub
+3 -2
View File
@@ -9,6 +9,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
import gc
import math
import time
from dataclasses import dataclass, asdict
@@ -565,8 +566,8 @@ while True:
train_loss_f = train_loss.item()
# Fast fail: abort if loss is exploding
if train_loss_f > 100:
# Fast fail: abort if loss is exploding or NaN
if math.isnan(train_loss_f) or train_loss_f > 100:
print("FAIL")
exit(1)