Change init scale from 0.7 to 0.68
This commit is contained in:
@@ -152,7 +152,7 @@ class GPT(nn.Module):
|
||||
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
|
||||
# Transformer blocks
|
||||
n_embd = self.config.n_embd
|
||||
- s = 0.7 * 3**0.5 * n_embd**-0.5
|
||||
+ s = 0.68 * 3**0.5 * n_embd**-0.5
|
||||
for block in self.transformer.h:
|
||||
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
||||
|
||||
Reference in New Issue
Block a user