Reduce transformer-block init scale from 0.8x to 0.7x

This commit is contained in:
autoresearch
2026-03-08 10:14:06 +00:00
parent 41d50a8539
commit f5979a7464
+1 -1
View File
@@ -152,7 +152,7 @@ class GPT(nn.Module):
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
# Transformer blocks
n_embd = self.config.n_embd
-s = 0.8 * 3**0.5 * n_embd**-0.5
+s = 0.7 * 3**0.5 * n_embd**-0.5
for block in self.transformer.h:
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)