diff --git a/train.py b/train.py index 6994fb9..1378bab 100644 --- a/train.py +++ b/train.py @@ -565,8 +565,8 @@ while True: train_loss_f = train_loss.item() - # Fast fail: abort if loss is exploding - if train_loss_f > 100: + # Fast fail: abort if loss is exploding or NaN + if not train_loss_f <= 100: print("FAIL") exit(1)