Commit 29416f6
feat: improving flow and score matching API and nets (#1544)
* shift VectorFieldNet type from estimators/flowmatching_estimator.py to utils/vector_field_utils.py
* refactor: update imports and enhance ConditionalScoreEstimator
- changed net type in ConditionalScoreEstimator and related classes to VectorFieldNet
- added embedding_net to ConditionalScoreEstimator for condition embedding
* Update z-score parameters in flowmatching and posterior score neural networks, keeping others as independent
* refactor: update vector field neural network architecture
- replaced flowmatcher and score estimator imports with vector field equivalents
- introduced new vector field neural network builders for MLP and transformer architectures
- enhanced flowmatching_nn and posterior_score_nn functions to support new model types
- added custom euler integration method in FlowMatchingEstimator for improved sampling
- updated z-score handling and embedding net integration in estimator classes
Remaining bugs:
- Extra num_samples dim in Zuko sampling function that we need to fix
* update: fix handling of shapes during sampling for npse and fmpe
- increased hidden_features from 50 to 64 in posterior_score_nn to make it divisible by num_heads
- added vector_field_fn to FlowMatchingEstimator for better dimension handling during sampling
- implemented reshaping logic in ConditionalScoreEstimator to manage sample dimensions during batch processing
* refactor: improve docstring formatting and remove commented-out code
* refactor: integrate embedding net in flowmatching estimator
- replaced direct condition input with embedded condition in log_prob, sample, and sample_and_log_prob methods
- removed commented-out code and print statements for cleaner implementation
- updated score estimator import to use posterior_score_nn in tests/score_samplers_test.py
* refactor: update flowmatching estimator and vector field network initialization
- removed noise_scale parameter from FlowMatchingEstimator
- adjusted vector field calculation in FlowMatchingEstimator for improved accuracy
- modified last layer initialization in AdaMLPBlock and DiTBlock to scale weights instead of zeroing
- streamlined MLP block processing in GlobalEmbeddingMLP for clarity
* Fix some problems
* To run CI, comment out broken tests
* Updates: Internal nets should be shared, but Estimator builders should be seperate! (as they have different preconditionrs)
* Unify shape handling in score and flow.
Add tests for consistentency of score and flow estimators
* Formating to get CI going (failing tests expected)
* Some small fixes and refactorings
* Fix ruff things
* Fixing score sampler tests with new net builder API
* Fixing flow estimator bugs
* Bug hunting + fixing
* Rearrange trainers + fixing tests to not use "special" hyperparameters to test (i.e .use defaults).
* Fix ruff
* Fixing failing tests
* Fix validation loss check
* Fix for new FMPE args
* Consistent naming for FMPE and NPSE
* Bug fix for Neural ODE sampling with ScoreEstimators
* less numsims for vfestimators in tests
* Test changes
* Bug fix, bounded epochs on default??? Add a better convergence chechk...
* Remove print
* new mlp which performs better...
* Allow setting num_sims in minisbibm for eval
* Add arg for num sims, cache results by default
* some refactorings
* Fix formating
* Formatting, make defaults more uniform
* Make factories more SBI-like
* Some estimator tests added
* Formatting, fix kwargs errors, more tests
* Fix tests and init transformer last layer as zero
* Remove test jupyter :/
* Formatting, refactoring tests
* Fix pyright
* Remove what is expected to fail
* Minor fixes
* Small docstring update
* Backward compatiblity warnings from some unused kwargs
* Typing with vectorfield net
* Simplify score estimator
* Updates
* Fixing transformer with cross attn
* Add error msg for unsupported shapes
* Better tests
* refactored tests
* Reverting wierd reshapings in score estimator.
Removing code duplication on embedding net handing
* Fix formating issues
* Fixing inconsistencies
* Fixing pyright
* Fix embedding_net not passed
* Fix embedding net bug
* Remove redundant "num_blocks"
* Adding some degree of backward compatibility on user interface.
* Fixing failing test on new convergence check
* Add transformer to bm
* Must be okay that the files already exits bm
* Fix merge bug. Add deprecation warnings for Score estimator keyword argument in NPSE
* Fixing transformers... (no pos emb. and others)
* Refactorings and tunings
* deprecation warnings and small refactorings
* Backwards compatibility
* Move score_estimator tests to vf_estimator_tests, run doc notebook once
* remove random wierd comment
* Remove tolerance special cases
* Consistent naming
* Faster convergence for slighly worse performance
* Backward compatibility for imports of NPSE and FMPE
* Docstring update
* Imporve docstrings
* Backward compatibility
* Use new keywords
* Format
* Add missing headers
* Update sbi/inference/trainers/vfpe/base_vf_inference.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update sbi/inference/trainers/fmpe/__init__.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update sbi/inference/trainers/npse/__init__.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update sbi/inference/trainers/vfpe/fmpe.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update sbi/inference/trainers/vfpe/fmpe.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update sbi/inference/trainers/vfpe/npse.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update sbi/inference/trainers/vfpe/npse.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update sbi/neural_nets/__init__.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update sbi/neural_nets/estimators/flowmatching_estimator.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update sbi/neural_nets/estimators/score_estimator.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update sbi/neural_nets/factory.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update sbi/neural_nets/factory.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update sbi/neural_nets/net_builders/vector_field_nets.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update sbi/neural_nets/net_builders/vector_field_nets.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update sbi/neural_nets/factory.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update tests/bm_test.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update tests/bm_test.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update tests/bm_test.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Update tests/bm_test.py
Co-authored-by: Jan <janfb@users.noreply.github.com>
* Add nugget as keyward argument to train
* Imporve converged docstring
* Better typing and docstrings and so on
* docstring
* Add context
* Extended docstring
* move protocol
* Formating fix
* Revert "move protocol"
This reverts commit 7eeaa7d.
* fix formating
* removing deprecated
* Fix typing
* Positional argument for default builder model name
* Formating
* unify nets test
* fixing builder
* update notebooks
* Formating and some text updates
* formating
* fix deprecation warning on default args
* unnecessary
* remove unecessary notes
* refactor check for deprecation warning
* fix mcmc params passing in test
* Fix mnle_test
* add missing import
* Notebooks rerun without warning and with striped notebook outputs
---------
Co-authored-by: Jaivardhan Kapoor <jaivardhan.kapoor@gmail.com>
Co-authored-by: Jan <janfb@users.noreply.github.com>
Co-authored-by: Jan <jan.boelts@mailbox.org>1 parent 7064286 commit 29416f6
File tree
30 files changed
+2680
-1182
lines changed- docs/advanced_tutorials
- sbi
- inference
- trainers
- npse
- vfpe
- neural_nets
- estimators
- net_builders
- utils
- tests
30 files changed
+2680
-1182
lines changedLines changed: 285 additions & 15 deletions
Large diffs are not rendered by default.
Lines changed: 98 additions & 23 deletions
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
7 | 7 | | |
8 | 8 | | |
9 | 9 | | |
10 | | - | |
11 | 10 | | |
12 | 11 | | |
13 | 12 | | |
14 | | - | |
15 | 13 | | |
| 14 | + | |
16 | 15 | | |
17 | 16 | | |
18 | 17 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
673 | 673 | | |
674 | 674 | | |
675 | 675 | | |
676 | | - | |
677 | | - | |
678 | | - | |
679 | | - | |
680 | | - | |
| 676 | + | |
| 677 | + | |
| 678 | + | |
| 679 | + | |
| 680 | + | |
| 681 | + | |
| 682 | + | |
| 683 | + | |
681 | 684 | | |
682 | | - | |
| 685 | + | |
| 686 | + | |
| 687 | + | |
| 688 | + | |
683 | 689 | | |
684 | 690 | | |
685 | 691 | | |
| |||
This file was deleted.
Lines changed: 2 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | | - | |
4 | | - | |
| 3 | + | |
| 4 | + | |
Lines changed: 93 additions & 8 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
7 | | - | |
| 7 | + | |
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
| |||
60 | 60 | | |
61 | 61 | | |
62 | 62 | | |
63 | | - | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
64 | 67 | | |
65 | 68 | | |
66 | 69 | | |
| |||
106 | 109 | | |
107 | 110 | | |
108 | 111 | | |
109 | | - | |
| 112 | + | |
110 | 113 | | |
111 | 114 | | |
112 | 115 | | |
113 | 116 | | |
114 | 117 | | |
115 | 118 | | |
116 | 119 | | |
117 | | - | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
118 | 125 | | |
119 | 126 | | |
120 | 127 | | |
| |||
209 | 216 | | |
210 | 217 | | |
211 | 218 | | |
212 | | - | |
213 | | - | |
| 219 | + | |
| 220 | + | |
214 | 221 | | |
215 | 222 | | |
216 | 223 | | |
217 | | - | |
| 224 | + | |
| 225 | + | |
218 | 226 | | |
219 | 227 | | |
220 | 228 | | |
| |||
253 | 261 | | |
254 | 262 | | |
255 | 263 | | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
256 | 267 | | |
257 | 268 | | |
258 | 269 | | |
| |||
341 | 352 | | |
342 | 353 | | |
343 | 354 | | |
344 | | - | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
345 | 358 | | |
346 | 359 | | |
347 | 360 | | |
| |||
431 | 444 | | |
432 | 445 | | |
433 | 446 | | |
| 447 | + | |
| 448 | + | |
434 | 449 | | |
435 | 450 | | |
436 | 451 | | |
| |||
486 | 501 | | |
487 | 502 | | |
488 | 503 | | |
| 504 | + | |
| 505 | + | |
| 506 | + | |
| 507 | + | |
| 508 | + | |
| 509 | + | |
| 510 | + | |
| 511 | + | |
| 512 | + | |
| 513 | + | |
| 514 | + | |
| 515 | + | |
| 516 | + | |
| 517 | + | |
| 518 | + | |
| 519 | + | |
| 520 | + | |
| 521 | + | |
| 522 | + | |
| 523 | + | |
| 524 | + | |
| 525 | + | |
| 526 | + | |
| 527 | + | |
| 528 | + | |
| 529 | + | |
| 530 | + | |
| 531 | + | |
| 532 | + | |
| 533 | + | |
| 534 | + | |
| 535 | + | |
| 536 | + | |
| 537 | + | |
| 538 | + | |
| 539 | + | |
| 540 | + | |
| 541 | + | |
| 542 | + | |
| 543 | + | |
| 544 | + | |
| 545 | + | |
| 546 | + | |
| 547 | + | |
| 548 | + | |
| 549 | + | |
| 550 | + | |
| 551 | + | |
| 552 | + | |
| 553 | + | |
| 554 | + | |
| 555 | + | |
| 556 | + | |
| 557 | + | |
| 558 | + | |
| 559 | + | |
| 560 | + | |
| 561 | + | |
| 562 | + | |
| 563 | + | |
| 564 | + | |
| 565 | + | |
| 566 | + | |
| 567 | + | |
| 568 | + | |
| 569 | + | |
| 570 | + | |
| 571 | + | |
| 572 | + | |
| 573 | + | |
489 | 574 | | |
490 | 575 | | |
491 | 576 | | |
| |||
Lines changed: 32 additions & 14 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2 | 2 | | |
3 | 3 | | |
4 | 4 | | |
| 5 | + | |
5 | 6 | | |
6 | 7 | | |
7 | 8 | | |
| |||
10 | 11 | | |
11 | 12 | | |
12 | 13 | | |
13 | | - | |
| 14 | + | |
14 | 15 | | |
15 | 16 | | |
16 | 17 | | |
17 | | - | |
18 | | - | |
| 18 | + | |
| 19 | + | |
19 | 20 | | |
20 | 21 | | |
21 | 22 | | |
| |||
24 | 25 | | |
25 | 26 | | |
26 | 27 | | |
27 | | - | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
28 | 33 | | |
29 | 34 | | |
30 | 35 | | |
| |||
35 | 40 | | |
36 | 41 | | |
37 | 42 | | |
38 | | - | |
39 | | - | |
40 | | - | |
41 | | - | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
42 | 48 | | |
| 49 | + | |
| 50 | + | |
43 | 51 | | |
44 | 52 | | |
45 | 53 | | |
| |||
48 | 56 | | |
49 | 57 | | |
50 | 58 | | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
51 | 68 | | |
52 | 69 | | |
53 | 70 | | |
54 | 71 | | |
55 | 72 | | |
56 | 73 | | |
57 | | - | |
| 74 | + | |
58 | 75 | | |
59 | 76 | | |
60 | | - | |
61 | | - | |
62 | 77 | | |
63 | 78 | | |
64 | 79 | | |
| |||
106 | 121 | | |
107 | 122 | | |
108 | 123 | | |
109 | | - | |
110 | | - | |
111 | | - | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
0 commit comments