Compare commits
13 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 32a1460f62 | |
| | 513fe6fcee | |
| | c2450add72 | |
| | 0be1e4fdf9 | |
| | ebf357841b | |
| | 09ebea439d | |
| | c12eef778e | |
| | b5ba8ac00d | |
| | 068d93da75 | |
| | c92bee55eb | |
| | 2224cd7cae | |
| | f16ece488f | |
| | 9264224a3c | |
@@ -18,3 +18,6 @@ AGENTS.md
# Experimental code/artifacts
dev/

# Results file
results.tsv
@@ -8,7 +8,7 @@ The idea: give an AI agent a small but real LLM training setup and let it experi
## How it works

The repo is deliberately kept small and only really has a three files that matter:
The repo is deliberately kept small and only really has three files that matter:

- **`prepare.py`** — fixed constants, one-time data prep (downloads training data, trains a BPE tokenizer), and runtime utilities (dataloader, evaluation). Not modified.
- **`train.py`** — the single file the agent edits. Contains the full GPT model, optimizer (Muon + AdamW), and training loop. Everything is fair game: architecture, hyperparameters, optimizer, batch size, etc. **This file is edited and iterated on by the agent**.
@@ -16,6 +16,8 @@ The repo is deliberately kept small and only really has a three files that matte
By design, training runs for a **fixed 5-minute time budget** (wall clock, excluding startup/compilation), regardless of the details of your compute. The metric is **val_bpb** (validation bits per byte) — lower is better, and vocab-size-independent so architectural changes are fairly compared.
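For intuition, bits per byte is just the model's cross-entropy converted from nats per token to bits per byte of raw text, which is what makes it independent of the tokenizer and vocab size. A minimal sketch of that conversion (the function name and the example numbers below are illustrative, not taken from `prepare.py`):

```python
import math

def bits_per_byte(total_loss_nats: float, total_bytes: int) -> float:
    """Convert summed cross-entropy (in nats) over some text into bits per byte.

    Illustrative only: assumes total_loss_nats is the sum of the per-token
    cross-entropy losses over the evaluated text, and total_bytes is the
    UTF-8 byte length of that same text.
    """
    total_bits = total_loss_nats / math.log(2)  # nats -> bits
    return total_bits / total_bytes

# e.g. 100k tokens at ~2.9 nats/token over ~430k bytes of text is ~0.97 bits/byte
print(bits_per_byte(100_000 * 2.9, 430_000))
```

Because the denominator is bytes of raw text rather than tokens, swapping tokenizers or changing the vocabulary size does not move the metric by itself.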
If you are new to neural networks, this ["Dummy's Guide"](https://x.com/hooeem/status/2030720614752039185) looks pretty good for a lot more context.

## Quick start

**Requirements:** A single NVIDIA GPU (tested on H100), Python 3.10+, [uv](https://docs.astral.sh/uv/).
@@ -37,8 +39,6 @@ uv run train.py
If the above commands all work ok, your setup is working and you can go into autonomous research mode.

**Platforms support**. This code currently requires that you have a single NVIDIA GPU. In principle it is quite possible to support CPU, MPS and other platforms but this would also bloat the code. I'm not 100% sure that I want to take this on personally right now. The code is just a demonstration and I don't know how much I'll support it going forward. People can reference (or have their agents reference) the full/parent nanochat repository that has wider platform support and shows the various solutions (e.g. a Flash Attention 3 kernels fallback implementation, generic device support, autodetection, etc.), feel free to create forks or discussions for other platforms and I'm happy to link to them here in the README in some new notable forks section or etc.

## Running the agent

Simply spin up your Claude/Codex or whatever you want in this repo (and disable all permissions), then you can prompt something like:
@@ -64,9 +64,28 @@ pyproject.toml — dependencies
- **Fixed time budget.** Training always runs for exactly 5 minutes, regardless of your specific platform. This means you can expect approx 12 experiments/hour and approx 100 experiments while you sleep (see the quick arithmetic below). There are two upsides to this design decision. First, it makes experiments directly comparable regardless of what the agent changes (model size, batch size, architecture, etc). Second, it means that autoresearch will find the best model for your platform within that time budget. The downside is that your runs (and results) are not comparable to those of people running on other compute platforms.
- **Self-contained.** No external dependencies beyond PyTorch and a few small packages. No distributed training, no complex configs. One GPU, one file, one metric.
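The throughput arithmetic behind those numbers, ignoring startup and compilation overhead:

```python
BUDGET_MIN = 5                                    # fixed training budget per experiment (wall clock)
experiments_per_hour = 60 // BUDGET_MIN           # 12 experiments/hour
experiments_overnight = 8 * experiments_per_hour  # 96, i.e. roughly 100 while you sleep
```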
## Platform support

This code currently requires that you have a single NVIDIA GPU. In principle it is quite possible to support CPU, MPS and other platforms, but this would also bloat the code. I'm not 100% sure that I want to take this on personally right now. People can reference (or have their agents reference) the full/parent nanochat repository, which has wider platform support and shows the various solutions (e.g. a Flash Attention 3 kernel fallback implementation, generic device support, autodetection, etc.). Feel free to create forks or discussions for other platforms and I'm happy to link to them here in the README's notable forks section.

Seeing as there seems to be a lot of interest in tinkering with autoresearch on much smaller compute platforms than an H100, a few extra words. If you're going to try running autoresearch on smaller computers (MacBooks etc.), I'd recommend one of the forks below. On top of this, here are some recommendations for how to tune the defaults for much smaller models, for aspiring forks:

1. To get half-decent results I'd use a dataset with a lot less entropy, e.g. this [TinyStories dataset](https://huggingface.co/datasets/karpathy/tinystories-gpt4-clean) of GPT-4-generated short stories. Because the data is much narrower in scope, you will see reasonable results with much smaller models (if you try to sample from them after training).
2. You might experiment with decreasing `vocab_size`, e.g. from 8192 down to 4096, 2048, 1024, or even a simple byte-level tokenizer with 256 possible bytes after UTF-8 encoding.
3. In `prepare.py`, you'll want to lower `MAX_SEQ_LEN` a lot, depending on the computer even down to 256 or so. As you lower `MAX_SEQ_LEN`, you may want to experiment with increasing `DEVICE_BATCH_SIZE` in `train.py` slightly to compensate; the number of tokens per fwd/bwd pass is the product of these two.
4. Also in `prepare.py`, you'll want to decrease `EVAL_TOKENS` so that your validation loss is evaluated on a lot less data.
5. In `train.py`, the primary single knob that controls model complexity is `DEPTH` (default 8 here). A lot of other variables are just functions of it, so e.g. lower it down to 4.
6. You'll most likely want to use a `WINDOW_PATTERN` of just "L", because "SSSL" uses an alternating banded attention pattern that may be very inefficient for you. Try it.
7. You'll want to lower `TOTAL_BATCH_SIZE` a lot, but keep it a power of 2, e.g. down to `2**14` (~16K) or so; hard to tell.

I think these would be the reasonable hyperparameters to play with. Ask your favorite coding agent for help and copy-paste it this guide, as well as the full source code.
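As a purely illustrative starting point, the recommendations above might translate into overrides roughly like the following. The constant names come from the guide above, but every value is a guess to be tuned per machine:

```python
# Illustrative small-machine overrides for a fork; all values are guesses, tune per machine.

# prepare.py
MAX_SEQ_LEN = 256         # item 3: much shorter sequences
EVAL_TOKENS = 50_000      # item 4: evaluate the val loss on far less data
VOCAB_SIZE = 2048         # item 2: smaller BPE vocab (256 would be byte-level)

# train.py
DEPTH = 4                 # item 5: the primary model-complexity knob, down from 8
WINDOW_PATTERN = "L"      # item 6: skip the banded "SSSL" attention pattern
TOTAL_BATCH_SIZE = 2**14  # item 7: ~16K tokens per optimizer step
DEVICE_BATCH_SIZE = 64    # item 3: tokens per fwd/bwd = MAX_SEQ_LEN * DEVICE_BATCH_SIZE = 2**14
```

With `MAX_SEQ_LEN * DEVICE_BATCH_SIZE` equal to `TOTAL_BATCH_SIZE`, each optimizer step is a single forward/backward pass, which is a reasonable place to start on a small GPU or laptop.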
## Notable forks

- [miolini/autoresearch-macos](https://github.com/miolini/autoresearch-macos)
- [miolini/autoresearch-macos](https://github.com/miolini/autoresearch-macos) (MacOS)
- [trevin-creator/autoresearch-mlx](https://github.com/trevin-creator/autoresearch-mlx) (MacOS)
- [jsegov/autoresearch-win-rtx](https://github.com/jsegov/autoresearch-win-rtx) (Windows)
- [andyluo7/autoresearch](https://github.com/andyluo7/autoresearch) (AMD)

## License
@@ -258,6 +258,7 @@ def _document_batches(split, tokenizer_batch_size=128):
    val_path = os.path.join(DATA_DIR, VAL_FILENAME)
    if split == "train":
        parquet_paths = [p for p in parquet_paths if p != val_path]
        assert len(parquet_paths) > 0, "No training shards found."
    else:
        parquet_paths = [val_path]
    epoch = 1
@@ -13,7 +13,7 @@ To set up a new experiment, work with the user to:
- `prepare.py` — fixed constants, data prep, tokenizer, dataloader, evaluation. Do not modify.
- `train.py` — the file you modify. Model architecture, optimizer, training loop.
4. **Verify data exists**: Check that `~/.cache/autoresearch/` contains data shards and a tokenizer. If not, tell the human to run `uv run prepare.py`.
5. **Initialize results.tsv**: Create `results.tsv` with header row and baseline entry. The baseline results are already known from the output format section below (val_bpb: 0.997900, peak_vram_mb: 45060.2). Do NOT re-run the baseline — just record it.
5. **Initialize results.tsv**: Create `results.tsv` with just the header row. The baseline will be recorded after the first run.
6. **Confirm and go**: Confirm setup looks good.

Once you get confirmation, kick off the experimentation.
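A minimal sketch of the `results.tsv` initialization in step 5, assuming the same tab-separated columns as the results.tsv that appears later in this diff (commit, val_bpb, memory_gb, status, description):

```python
# Hypothetical helper for step 5: create results.tsv with only the header row.
from pathlib import Path

COLUMNS = ["commit", "val_bpb", "memory_gb", "status", "description"]

def init_results(path: str = "results.tsv") -> None:
    p = Path(path)
    if not p.exists():  # don't clobber an existing experiment log
        p.write_text("\t".join(COLUMNS) + "\n")

init_results()
```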
@@ -99,7 +99,7 @@ LOOP FOREVER:
4. Run the experiment: `uv run train.py > run.log 2>&1` (redirect everything — do NOT use tee or let output flood your context)
5. Read out the results: `grep "^val_bpb:\|^peak_vram_mb:" run.log`
6. If the grep output is empty, the run crashed. Run `tail -n 50 run.log` to read the Python stack trace and attempt a fix. If you can't get things to work after more than a few attempts, give up.
7. Record the results in the tsv
7. Record the results in the tsv (NOTE: do not commit the results.tsv file, leave it untracked by git)
8. If val_bpb improved (lower), you "advance" the branch, keeping the git commit
9. If val_bpb is equal or worse, you git reset back to where you started
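For concreteness, here is a rough sketch of one iteration of that loop (steps 4 through 9), written as if it were driven from Python rather than by the agent directly. The helper name, the tsv columns, and the exact git command are assumptions, and the log parsing uses plain Python string handling instead of grep; the run/record/advance-or-reset logic mirrors the steps above:

```python
# Illustrative sketch of one experiment iteration (steps 4-9 above); not the agent's actual code.
import subprocess

def run_experiment(best_bpb, commit, description, tsv_path="results.tsv"):
    # Step 4: run training, redirecting stdout+stderr to run.log
    with open("run.log", "w") as log:
        subprocess.run(["uv", "run", "train.py"], stdout=log, stderr=subprocess.STDOUT)

    # Step 5: pull the metric lines out of the log
    metrics = {}
    with open("run.log") as f:
        for line in f:
            if line.startswith(("val_bpb:", "peak_vram_mb:")):
                key, value = line.split(":", 1)
                metrics[key] = float(value)

    # Step 6: no metrics means the run crashed; look at the end of the log
    if "val_bpb" not in metrics:
        tail = subprocess.run(["tail", "-n", "50", "run.log"], capture_output=True, text=True)
        print(tail.stdout)
        return best_bpb  # fix and retry, or give up after a few attempts

    # Steps 8-9 decision: advance on improvement, otherwise reset
    improved = metrics["val_bpb"] < best_bpb
    status = "keep" if improved else "discard"

    # Step 7: append to results.tsv (left untracked by git)
    memory_gb = metrics.get("peak_vram_mb", 0.0) / 1024
    with open(tsv_path, "a") as tsv:
        tsv.write(f"{commit}\t{metrics['val_bpb']:.6f}\t{memory_gb:.1f}\t{status}\t{description}\n")

    if improved:
        return metrics["val_bpb"]  # keep the commit ("advance" the branch)
    # Discard the experiment, assuming the change was committed before the run
    subprocess.run(["git", "reset", "--hard", "HEAD~1"])
    return best_bpb
```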
@@ -1,127 +0,0 @@
commit val_bpb memory_gb status description
baseline 0.997900 44.0 keep baseline
bea057b 0.986041 43.9 keep halve batch 524K to 262K (more steps in 5 min)
7f2a65c 0.981773 60.2 keep depth 9 aspect_ratio 57 (extra layer dim ~512)
187e419 0.982603 60.2 discard add 5% warmup
4e6697f 0.981201 60.2 keep warmdown 0.5 to 0.7
8363d52 0.980903 60.2 keep SSSSL window pattern (5:1 short:long)
7da0b67 0.979969 60.2 keep short window 1/8 context (256 tokens)
59e9dd9 0.978784 60.2 keep RoPE base frequency 10K to 200K
7d047e4 0.975524 60.2 keep embedding LR 0.6 to 0.8
c4ce95c 0.975895 60.2 discard unembedding LR 0.004 to 0.008
0640555 0.974729 60.2 keep x0_lambda init 0.1 to 0.05
772dada 0.974119 60.2 keep FINAL_LR_FRAC 0.0 to 0.05
ccf6012 0.974903 60.2 discard matrix LR 0.04 to 0.045
aa8f408 0.973104 60.2 keep unembedding LR 0.004 to 0.006
889dbed 0.973799 60.2 discard random seed 42 to 137
e05a87d 0.000000 0.0 crash batch 131K (assert fail: not divisible by device batch)
0dc6130 0.974134 60.2 discard embedding LR 0.8 to 1.0
fa14910 0.973824 60.2 discard softcap 15 to 20
187db84 0.973659 60.2 discard warmdown 0.7 to 0.8
3ec9cfa 0.979340 53.7 discard depth 10 aspect 51 dim 512 (too narrow)
2913af9 0.973177 60.2 discard weight decay 0.2 to 0.15
a7aa309 0.972849 60.2 keep muon momentum warmup 300 to 200 steps
6b6b241 0.973385 60.2 discard VE gate channels 32 to 48
77d7a47 0.973121 60.2 discard scalar LR 0.5 to 0.3
d6cad11 0.973130 60.2 discard Adam beta1 0.8 to 0.85
b19f649 0.978313 60.2 discard remove cautious WD mask (much worse)
9aa1e29 0.974490 60.2 discard FINAL_LR_FRAC 0.05 to 0.1
aa61e0b 0.973658 60.2 discard gradient clipping max_norm=1.0
31b838e 0.973706 60.2 discard Muon ns_steps 5 to 4
ebff004 0.973076 60.2 discard LR scale reference 768 to 640
d3c7143 0.973339 60.2 discard muon final momentum 0.95 to 0.96
a6b2ac7 0.973119 60.2 discard Muon beta2 0.95 to 0.90
6b59591 0.972821 60.2 discard VE gate scale 2 to 3 (flat)
d41c1df 0.976114 60.2 discard resid lambda init 1.0 to 0.9
97aa364 0.973828 60.2 discard matrix LR 0.04 to 0.035
5db6ed4 0.979735 59.5 discard VE only last 3 layers (much worse)
8189d27 0.973894 60.2 discard embedding LR 0.8 to 0.9
7f63c17 0.972779 60.2 keep unembedding LR 0.006 to 0.005
d0662d1 0.974038 60.2 discard RoPE base 200K to 400K
d3840ec 0.974356 60.2 discard constant WD at 0.1 (decaying better)
264a05b 0.972694 60.2 keep add WD 0.01 to lm_head
7d3f0e4 0.972847 60.2 discard softcap 15 to 13
00a7c09 0.979754 72.6 discard depth 11 dim 640 (too big, too few steps)
674a510 0.975033 60.2 discard add WD 0.01 to embeddings (hurts)
b1f02f7 0.975328 60.2 discard add 2% warmup (any warmup hurts)
81261f5 0.973469 60.2 discard halve value embedding LR
51f0499 0.972844 60.2 discard x0_lambda beta1 0.96 to 0.90
de48b64 0.974912 60.2 discard SSSL pattern (more long layers hurt steps)
01a3c69 0.973105 60.2 discard FINAL_LR_FRAC 0.05 to 0.02
86c7e66 0.974639 60.2 discard lm_head init std 0.001 to 0.01
489bb99 0.976462 60.2 discard x0_lambda init 0.0 (x0 skip important)
a16391b 0.973059 60.2 discard rotary precompute 10x to 2x
8dd93ec 0.972712 60.2 discard VE LR 1.5x (flat)
802d184 0.974123 60.2 discard embedding init std 1.0 to 2.0
2b9a688 0.974331 60.2 discard sqrt WD schedule
ffcb3c2 0.972982 60.2 discard muon start momentum 0.85 to 0.80
3cde993 0.974655 66.2 discard depth 10 same dim 640 (too few steps)
8be9036 0.975285 53.9 discard depth 8 dim 640 (too shallow)
2271cc2 0.974190 60.2 discard WD follows LR schedule
46cf5f2 0.983719 54.6 discard parallel attn+MLP (much worse)
59316b9 0.973312 60.2 discard warmdown 0.7 to 0.65
c4b0731 0.973803 57.4 discard MLP hidden 4x to 3.5x
6193116 0.973173 60.2 discard RoPE base 200K to 150K
c1f79a6 0.973005 60.2 discard FINAL_LR_FRAC 0.05 to 0.03
ee60bf7 0.976203 60.2 discard SSSSSL pattern (too few long layers)
a7b953a 0.973088 60.2 discard lm_head WD 0.01 to 0.05
41d50a8 0.972258 60.2 keep reduce transformer init scale by 0.8x
991abb2 0.972721 60.2 discard init scale 0.6x (0.8 better)
f5979a7 0.972128 60.2 keep init scale 0.7x
2216fd6 0.973025 60.2 discard init scale 0.65x (0.7 better)
ddcd35a 0.972587 60.2 discard embedding init std 1.0 to 0.7
8934eec 0.972776 60.2 discard lm_head init std 0.001 to 0.002
92b4765 0.973847 60.2 discard small random init for c_proj (worse)
d385aa7 0.972901 60.2 discard scalar LR 0.5 to 0.7
db37d12 0.973155 60.2 discard unembedding LR 0.005 to 0.004
f04daec 0.973155 60.2 discard weight decay 0.2 to 0.25
d931c3a 0.975790 60.2 discard x0_lambda init 0.05 to 0.04 (worse)
c5a4645 0.972216 60.2 discard VE init scale 0.5x of transformer init
30f1b8d 0.973361 60.2 discard cosine warmdown schedule (linear better)
5a9c951 0.972877 63.1 discard MLP hidden 4x to 4.5x (fewer steps)
ab8f970 0.975964 60.2 discard decreasing resid_lambda init (hurts)
2a3f587 0.972901 60.2 discard softcap 15 to 14
362937e 0.972495 60.2 discard VE gate channels 32 to 16
0d77d4d 0.972621 60.2 discard Adam beta2 0.95 to 0.99
4eebd43 0.973493 60.2 discard x0_lambda LR 2x
b85567f 0.979987 52.0 discard multi-query attention n_kv_head=1 (too few KV heads)
0da44e6 0.973545 60.2 discard small nonzero init for c_proj (zero better)
d6c139a 0.973831 60.2 discard embedding init std 1.0 to 0.5
d70987b 3.215849 60.2 discard weight tying (shared embed/unembed, broken)
bff5cda 0.975852 59.5 discard VE every 3rd layer (too few VEs)
5953d58 0.973423 60.2 discard WD constant until warmdown then decay
d1eb994 0.974314 60.2 discard smaller QK init 0.5x (uniform init matters for Muon)
3c19fba 0.974046 60.2 discard depth-dependent init scale 1/sqrt(layer+1)
119065a 0.972335 60.2 discard init scale 0.7 to 0.72
97dda85 0.972097 60.2 keep init scale 0.7 to 0.68
58b8b7a 0.972350 60.2 discard init scale 0.68 to 0.66 (0.68 better)
70c2737 0.972731 60.2 discard Muon NorMuon beta2 0.95 to 0.98
8232e01 0.973000 60.2 discard resid_lambda LR 0.01x to 0.04x
21389c4 0.973723 60.2 discard Adam beta1 0.8 to 0.9
e4c0f3e 0.974043 60.2 discard short window 1/6 context (slower)
2e2a2f8 0.972632 60.2 discard short window 1/10 context (quality loss)
9db7b86 0.972744 60.2 discard lm_head init std 0.001 to 0.0005
ece9101 0.972009 60.2 keep tiny embedding WD 0.001
b07c56b 0.972438 60.2 discard embedding WD 0.001 to 0.002
1a85362 0.971058 60.2 keep tiny VE WD 0.001
73c77ca 0.970655 60.2 keep VE WD 0.001 to 0.002
637f82f 0.970433 60.2 keep VE WD 0.002 to 0.003
c152812 0.970644 60.2 discard VE WD 0.003 to 0.005 (0.003 better)
efd2171 0.970703 60.2 discard embedding WD 0.001 to 0.002
328de7c 0.970612 60.2 discard lm_head WD 0.01 to 0.02
c0c2349 0.970758 60.2 discard lm_head WD 0.01 to 0.005
b1d5004 0.969952 60.2 keep embedding LR 0.8 to 0.9 (with WD)
2ca8872 0.970767 60.2 discard embedding LR 0.9 to 1.0
74a3b33 0.970759 60.2 discard unembedding LR 0.005 to 0.006
d1f68da 0.970106 60.2 discard embedding WD 0.001 to 0.002 (with LR 0.9)
ebbe8c0 0.971004 60.2 discard matrix LR 0.04 to 0.045
b9ee7d6 0.970040 60.2 discard VE WD 0.003 to 0.004
2f0a8ec 0.970573 60.2 discard Muon WD 0.2 to 0.22
438a26e 0.969686 60.2 keep warmdown 0.7 to 0.75
d9322b9 0.970244 60.2 discard warmdown 0.75 to 0.8
8876cf3 0.969714 60.2 discard FINAL_LR_FRAC 0.05 to 0.03
80330e2 0.970135 60.2 discard x0_lambda init 0.05 to 0.06
2f0cec6 0.970678 60.2 discard RoPE base 200K to 300K
c044a14 0.970212 60.2 discard VE gate scale 2 to 3
80a519a 0.969857 60.2 discard VE LR 1.5x with WD
a6b6476 0.970286 60.2 discard muon momentum warmup 200 to 150 steps
@@ -9,6 +9,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

import gc
import math
import time
from dataclasses import dataclass, asdict

@@ -152,7 +153,7 @@ class GPT(nn.Module):
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
        # Transformer blocks
        n_embd = self.config.n_embd
        s = 0.68 * 3**0.5 * n_embd**-0.5
        s = 3**0.5 * n_embd**-0.5
        for block in self.transformer.h:
            torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
@@ -162,7 +163,7 @@ class GPT(nn.Module):
            torch.nn.init.zeros_(block.mlp.c_proj.weight)
        # Per-layer scalars
        self.resid_lambdas.fill_(1.0)
        self.x0_lambdas.fill_(0.05)
        self.x0_lambdas.fill_(0.1)
        # Value embeddings
        for ve in self.value_embeds.values():
            torch.nn.init.uniform_(ve.weight, -s, s)
@@ -179,7 +180,7 @@ class GPT(nn.Module):
        for ve in self.value_embeds.values():
            ve.to(dtype=torch.bfloat16)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=200000, device=None):
    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        if device is None:
            device = self.transformer.wte.weight.device
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
@@ -195,7 +196,7 @@ class GPT(nn.Module):
        pattern = config.window_pattern.upper()
        assert all(c in "SL" for c in pattern)
        long_window = config.sequence_len
        short_window = long_window // 8
        short_window = long_window // 2
        char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
        window_sizes = []
        for layer_idx in range(config.n_layer):
@@ -247,9 +248,9 @@ class GPT(nn.Module):
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")
        param_groups = [
            dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.01),
            dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.001),
            dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.003),
            dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
            dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0),
        ]
@@ -429,24 +430,24 @@ class MuonAdamW(torch.optim.Optimizer):
# ---------------------------------------------------------------------------

# Model architecture
ASPECT_RATIO = 57 # model_dim = depth * ASPECT_RATIO
ASPECT_RATIO = 64 # model_dim = depth * ASPECT_RATIO
HEAD_DIM = 128 # target head dimension for attention
WINDOW_PATTERN = "SSSSL" # sliding window pattern: L=full, S=half context
WINDOW_PATTERN = "SSSL" # sliding window pattern: L=full, S=half context

# Optimization
TOTAL_BATCH_SIZE = 2**18 # ~262K tokens per optimizer step
EMBEDDING_LR = 0.9 # learning rate for token embeddings (Adam)
UNEMBEDDING_LR = 0.005 # learning rate for lm_head (Adam)
TOTAL_BATCH_SIZE = 2**19 # ~524K tokens per optimizer step
EMBEDDING_LR = 0.6 # learning rate for token embeddings (Adam)
UNEMBEDDING_LR = 0.004 # learning rate for lm_head (Adam)
MATRIX_LR = 0.04 # learning rate for matrix parameters (Muon)
SCALAR_LR = 0.5 # learning rate for per-layer scalars (Adam)
WEIGHT_DECAY = 0.2 # cautious weight decay for Muon
ADAM_BETAS = (0.8, 0.95) # Adam beta1, beta2
WARMUP_RATIO = 0.0 # fraction of time budget for LR warmup
WARMDOWN_RATIO = 0.75 # fraction of time budget for LR warmdown
FINAL_LR_FRAC = 0.05 # final LR as fraction of initial
WARMDOWN_RATIO = 0.5 # fraction of time budget for LR warmdown
FINAL_LR_FRAC = 0.0 # final LR as fraction of initial

# Model size
DEPTH = 9 # number of transformer layers
DEPTH = 8 # number of transformer layers
DEVICE_BATCH_SIZE = 128 # per-device batch size (reduce if OOM)

# ---------------------------------------------------------------------------
@@ -524,7 +525,7 @@ def get_lr_multiplier(progress):
    return cooldown * 1.0 + (1 - cooldown) * FINAL_LR_FRAC

def get_muon_momentum(step):
    frac = min(step / 200, 1)
    frac = min(step / 300, 1)
    return (1 - frac) * 0.85 + frac * 0.95

def get_weight_decay(progress):
@@ -565,8 +566,8 @@ while True:

    train_loss_f = train_loss.item()

    # Fast fail: abort if loss is exploding
    if train_loss_f > 100:
    # Fast fail: abort if loss is exploding or NaN
    if math.isnan(train_loss_f) or train_loss_f > 100:
        print("FAIL")
        exit(1)