
Commit bdf55a2

sirmarcel and claude committed
Expand README training docs, default to float32 matmul precision
- Document full training workflow (data prep, config, running)
- Explain $DATASETS environment variable
- Add collapsible reference table for all training settings
- Default matmul precision to float32 instead of default

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a14365a commit bdf55a2

2 files changed

Lines changed: 101 additions & 4 deletions


README.md

Lines changed: 100 additions & 3 deletions
@@ -39,14 +39,111 @@ calc = Calculator.from_checkpoint("path/to/checkpoint")

### Training

Training a model involves three steps: preparing the data, configuring the model and training settings, and running the training script.

#### 1. Prepare data

Training data is stored in [marathon](https://github.com/sirmarcel/marathon) format. Convert your extended XYZ dataset using a preparation script (see `examples/train-mlp/prepare.py` for a template):

```python
from marathon.data import datasets, get_splits
from marathon.grain import prepare

# datasets is a Path resolved from the $DATASETS environment variable
prepare(train_atoms, folder=datasets / "my_project/train", ...)
prepare(valid_atoms, folder=datasets / "my_project/valid", ...)
```
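
Here, `train_atoms` and `valid_atoms` are lists of `ase.Atoms`; a minimal sketch of producing them, assuming your data lives in two extended XYZ files (the filenames are placeholders):

```python
from ase.io import read

# Read all frames from each extended XYZ file as a list of ase.Atoms.
train_atoms = read("train.xyz", index=":")
valid_atoms = read("valid.xyz", index=":")
```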

The `$DATASETS` environment variable sets the root directory where prepared datasets are stored. All dataset paths in `settings.yaml` are resolved relative to this directory.
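
In other words (a sketch of the resolution rule, not of the library's internals):

```python
import os
from pathlib import Path

datasets = Path(os.environ["DATASETS"])       # e.g. /data
train_folder = datasets / "my_project/train"  # -> /data/my_project/train
```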

#### 2. Configure the experiment

Each experiment lives in its own directory containing two YAML files:

**`model.yaml`** defines the model architecture:

```yaml
model:
  lorem.Lorem:
    cutoff: 5.0
    max_degree: 4
    max_degree_lr: 2
    num_features: 128
    num_spherical_features: 4
    num_message_passing: 1
```

Use `lorem.LoremBEC` instead of `lorem.Lorem` to train a model that additionally predicts Born effective charges, as sketched below.
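
For instance, a BEC-predicting variant of the config above (a sketch; it assumes `lorem.LoremBEC` accepts the same hyperparameters as `lorem.Lorem`):

```yaml
model:
  lorem.LoremBEC:
    cutoff: 5.0        # remaining hyperparameters as in the example above
    num_features: 128
```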
77+
78+
**`settings.yaml`** configures training:
79+
80+
```yaml
81+
train: "my_project/train" # path relative to $DATASETS
82+
valid: "my_project/valid" # path relative to $DATASETS
83+
seed: 23
84+
batcher:
85+
batch_size: 4
86+
loss_weights: {"energy": 0.5, "forces": 0.5}
87+
optimizer: adam # adam or muon
88+
start_learning_rate: 1e-3
89+
min_learning_rate: 1e-6
90+
max_epochs: 2000
91+
valid_every_epoch: 2
92+
decay_style: linear # linear, exponential, or warmup_cosine
93+
use_wandb: True
94+
```

<details>
<summary>All training settings</summary>

| Setting | Default | Description |
|---|---|---|
| `train` | *required* | Training dataset path (relative to `$DATASETS`) |
| `valid` | *required* | Validation dataset path (relative to `$DATASETS`) |
| `test_datasets` | `{}` | Extra test datasets: `{name: [path, save_predictions]}` |
| `batcher.batch_size` | *required* | Samples per batch |
| `batcher.size_strategy` | `powers_of_4` | Padding strategy for batch dimensions |
| `loss_weights` | `{"energy": 0.5, "forces": 0.5}` | Per-target loss weights |
| `scale_by_variance` | `False` | Scale loss weights by validation set variance |
| `optimizer` | `adam` | Optimizer (`adam`, `muon`, or any optax optimizer) |
| `start_learning_rate` | `1e-3` | Initial learning rate |
| `min_learning_rate` | `1e-6` | Minimum learning rate |
| `max_epochs` | `2000` | Maximum training epochs |
| `valid_every_epoch` | `2` | Validate every N epochs |
| `decay_style` | `linear` | LR schedule: `linear`, `exponential`, or `warmup_cosine` |
| `start_decay_after` | `10` | Epoch to begin LR decay |
| `stop_decay_after` | `max_epochs` | Epoch to end LR decay (linear only) |
| `warmup_epochs` | `0` | Warmup epochs (`warmup_cosine` only) |
| `gradient_clip` | `0` | Gradient clipping threshold (0 = disabled) |
| `seed` | `0` | Random seed |
| `rotational_augmentation` | `False` | Apply random rotations to training data |
| `filter_mixed_pbc` | `False` | Filter out structures with mixed periodic boundary conditions |
| `filter_above_num_atoms` | `False` | Filter out structures above this atom count (`False` = disabled) |
| `checkpointers` | `default` | `default` or `full` (adds RMSE checkpointers) |
| `use_wandb` | `True` | Log to Weights & Biases |
| `wandb_project` | auto | W&B project name (default: derived from folder names) |
| `wandb_name` | auto | W&B run name (default: experiment folder name) |
| `benchmark_pipeline` | `True` | Benchmark data pipeline before training |
| `compilation_cache` | `False` | Enable JAX persistent compilation cache |
| `default_matmul_precision` | `float32` | JAX matmul precision (`default` or `float32`) |
| `debug_nans` | `False` | Enable JAX NaN debugging (~50% slowdown) |
| `enable_x64` | `False` | Enable 64-bit floating point |
| `worker_count` | `4` | Data loading workers (training) |
| `worker_count_valid` | `worker_count` | Data loading workers (validation) |
| `worker_buffer_size` | `2` | Prefetch buffer per worker (training) |

</details>
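
For example, an extra test set in the `test_datasets` format from the table above (the name and path are placeholders; `save_predictions` is assumed to be a boolean):

```yaml
test_datasets:
  my_test: ["my_project/test", False]
```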

#### 3. Run training

```bash
cd my_experiment
DATASETS=/path/to/datasets lorem-train
```

Training writes checkpoints, logs, and plots to a `run/` directory inside the experiment folder. If a `run/` directory already exists, training resumes from the latest checkpoint.
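
Once trained, a checkpoint can be loaded for inference as in the quickstart above (a sketch; the import path is assumed, and the concrete checkpoint path inside `run/` depends on the checkpointer):

```python
from lorem import Calculator  # import path assumed from the quickstart

# "path/to/checkpoint" stands in for a concrete checkpoint inside run/
calc = Calculator.from_checkpoint("path/to/checkpoint")
```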

See `examples/train-mlp/` and `examples/train-bec/` for complete examples including data preparation and configuration files.

### Model variants

src/lorem/train.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -73,7 +73,7 @@ def main():
     wandb_project = settings.pop("wandb_project", None)
     wandb_name = settings.pop("wandb_name", None)

-    default_matmul_precision = settings.pop("default_matmul_precision", "default")
+    default_matmul_precision = settings.pop("default_matmul_precision", "float32")
     debug_nans = settings.pop("debug_nans", False)  # ~50% slowdown, use with care
     enable_x64 = settings.pop("enable_x64", False)
```
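
For context, this setting maps onto JAX's global matmul-precision flag; a sketch of the equivalent manual configuration (not necessarily how `lorem-train` applies it internally):

```python
import jax

# Force float32 matmul accumulation; the "default" setting may use
# lower-precision (e.g. TF32/bfloat16) matmuls on some accelerators.
jax.config.update("jax_default_matmul_precision", "float32")
```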
