Commit 7c12230
committed
feat: unify edge_degree + layer radial MLPs into single batched GEMM
UnifiedRadialMLP consolidates edge_degree_embedding.rad_func and all layer
rad_funcs into a single first-layer GEMM, reducing kernel launches and
improving GPU utilization.
Key changes:
- UnifiedRadialMLP: batches first linear layer, processes tails separately
- get_unified_radial_emb: returns [edge_degree_out, layer_0_out, ...]
- rad_func=None sentinel: signals precomputed radials in EdgeDegreeEmbedding
- Fast backends (UMASFastPytorchBackend, UMASFastGPUBackend) create and use
UnifiedRadialMLP at prepare_model_for_inference time
Also includes torch.compile compatibility fixes:
- ChgSpinEmbedding: replaced dict lookup with tensor arithmetic
- balance_channels: minor cleanup for compile compatibility
Performance: 17.4 QPS on 2000 atoms (H200), forces match baseline.1 parent 331203d commit 7c12230
5 files changed
Lines changed: 355 additions & 201 deletions
File tree
- src/fairchem/core/models/uma
- nn
- tests/core/models/uma/nn
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
727 | 727 | | |
728 | 728 | | |
729 | 729 | | |
| 730 | + | |
| 731 | + | |
730 | 732 | | |
731 | 733 | | |
732 | 734 | | |
| |||
736 | 738 | | |
737 | 739 | | |
738 | 740 | | |
739 | | - | |
740 | | - | |
741 | | - | |
742 | | - | |
| 741 | + | |
| 742 | + | |
| 743 | + | |
743 | 744 | | |
744 | 745 | | |
745 | 746 | | |
| |||
772 | 773 | | |
773 | 774 | | |
774 | 775 | | |
| 776 | + | |
| 777 | + | |
| 778 | + | |
| 779 | + | |
| 780 | + | |
| 781 | + | |
| 782 | + | |
| 783 | + | |
| 784 | + | |
| 785 | + | |
775 | 786 | | |
776 | 787 | | |
777 | | - | |
| 788 | + | |
778 | 789 | | |
779 | 790 | | |
780 | 791 | | |
| |||
784 | 795 | | |
785 | 796 | | |
786 | 797 | | |
787 | | - | |
788 | | - | |
789 | | - | |
790 | | - | |
791 | | - | |
792 | | - | |
793 | 798 | | |
794 | 799 | | |
795 | 800 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
82 | 82 | | |
83 | 83 | | |
84 | 84 | | |
85 | | - | |
| 85 | + | |
| 86 | + | |
86 | 87 | | |
87 | 88 | | |
88 | 89 | | |
| |||
150 | 151 | | |
151 | 152 | | |
152 | 153 | | |
153 | | - | |
154 | | - | |
155 | | - | |
156 | | - | |
157 | | - | |
158 | | - | |
159 | | - | |
160 | 154 | | |
161 | 155 | | |
162 | 156 | | |
| |||
173 | 167 | | |
174 | 168 | | |
175 | 169 | | |
176 | | - | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
177 | 180 | | |
178 | 181 | | |
179 | 182 | | |
| |||
199 | 202 | | |
200 | 203 | | |
201 | 204 | | |
202 | | - | |
203 | | - | |
204 | | - | |
205 | | - | |
206 | | - | |
207 | | - | |
208 | | - | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
209 | 211 | | |
210 | 212 | | |
211 | 213 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
88 | 88 | | |
89 | 89 | | |
90 | 90 | | |
91 | | - | |
| 91 | + | |
92 | 92 | | |
93 | 93 | | |
94 | 94 | | |
95 | 95 | | |
96 | | - | |
| 96 | + | |
97 | 97 | | |
98 | | - | |
99 | | - | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
100 | 101 | | |
101 | | - | |
| 102 | + | |
102 | 103 | | |
103 | 104 | | |
104 | 105 | | |
105 | 106 | | |
106 | 107 | | |
107 | 108 | | |
108 | | - | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
109 | 112 | | |
110 | | - | |
| 113 | + | |
111 | 114 | | |
112 | 115 | | |
113 | 116 | | |
| |||
242 | 245 | | |
243 | 246 | | |
244 | 247 | | |
245 | | - | |
| 248 | + | |
246 | 249 | | |
247 | 250 | | |
248 | 251 | | |
| |||
252 | 255 | | |
253 | 256 | | |
254 | 257 | | |
255 | | - | |
| 258 | + | |
256 | 259 | | |
257 | 260 | | |
258 | 261 | | |
| |||
291 | 294 | | |
292 | 295 | | |
293 | 296 | | |
294 | | - | |
| 297 | + | |
| 298 | + | |
295 | 299 | | |
296 | 300 | | |
297 | 301 | | |
| |||
302 | 306 | | |
303 | 307 | | |
304 | 308 | | |
305 | | - | |
306 | | - | |
307 | | - | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
308 | 320 | | |
309 | 321 | | |
310 | | - | |
| 322 | + | |
311 | 323 | | |
312 | 324 | | |
313 | 325 | | |
314 | 326 | | |
315 | | - | |
| 327 | + | |
316 | 328 | | |
317 | 329 | | |
318 | 330 | | |
319 | 331 | | |
320 | 332 | | |
321 | 333 | | |
322 | | - | |
| 334 | + | |
323 | 335 | | |
324 | 336 | | |
325 | 337 | | |
| |||
408 | 420 | | |
409 | 421 | | |
410 | 422 | | |
411 | | - | |
| 423 | + | |
412 | 424 | | |
413 | 425 | | |
414 | 426 | | |
415 | 427 | | |
416 | 428 | | |
417 | 429 | | |
418 | 430 | | |
419 | | - | |
| 431 | + | |
420 | 432 | | |
421 | 433 | | |
422 | 434 | | |
| |||
0 commit comments