diff --git a/train.py b/train.py index 11d27e7..fdef2cb 100644 --- a/train.py +++ b/train.py @@ -179,7 +179,7 @@ class GPT(nn.Module): for ve in self.value_embeds.values(): ve.to(dtype=torch.bfloat16) - def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): + def _precompute_rotary_embeddings(self, seq_len, head_dim, base=200000, device=None): if device is None: device = self.transformer.wte.weight.device channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)