Commit 1888205
refactoring noise_schedule and time schedule into base class (#1736)
* refactoring noise_schedule and time schedule into base class
- created noise_schedule method to be overwritten by derivatives
- created times_schedules method to be overwritten by derivatives
- created test on times_schedules
- improvded docstrings
* added noise schedule test
* implemented beta schedule for variance-preserving estimators
- added tests too
- inspired by https://arxiv.org/abs/2206.00364
* code cosmetics triggered by ruff
* cloned VPScoreEstimator to yield improved version
in addition:
- added improved version to benchmarks (for later comparison)
- created new class ImprovedVPScoreEstimator
* more realistic bounds for unit test
* typo and refactoring
to understand how the VE estimator is implemented
* fixed wrong setup of pmean and pstd
* code reformatting
* use the time schedule for computing the validation scores
* propagate name change
* fix unit tests to respect new schedules
* comply with formatting
* attempted to implement EDM-like diffusion
- without touching the forward function of ConditionalScoreEstimator
- benchmarks show that this leads to very long training time
without any performance improvements
* removed "improved" denoising network
* consolidated tests
* removed occurrances of vp++
* removed all mentions of edm
* ruff fixes
* WIP : use time schedule in loss function, address device issues
* call solve_schedule in validation step
* add solve_schedule method, call train_schedule in loss
* call the solve schedule during sampling with SDE
* add a solve_schedule function in the conditional vf estimator class to unify training in vftrainer class
* make the solve schedule deterministic
* corrections on solve schedule
* WIP : create solve schedule in base class
* modify arguments of solve schedule
* change train_schedule + docstrings fixes + device handling
* include validation times nugget to avoid instabilities during training
* change the nb of simulations for ve option
* change device in solve schedule
* add noise schedule in VE subclass
* reshape noise schedule output in VE class
* reshape noise schedule output in VE class
* add tests on train and solve schedule shapes, devices, bounds
* formatting and changing tests
---------
Co-authored-by: Peter Steinbach <p.steinbach@hzdr.de>
Co-authored-by: Camille Touron <ctouron@ptb-07008323.grenoble.inria.fr>1 parent 937efc2 commit 1888205
File tree
6 files changed
+208
-56
lines changed- sbi
- inference
- posteriors
- trainers/vfpe
- neural_nets/estimators
- tests
6 files changed
+208
-56
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
313 | 313 | | |
314 | 314 | | |
315 | 315 | | |
316 | | - | |
317 | | - | |
| 316 | + | |
| 317 | + | |
318 | 318 | | |
319 | 319 | | |
320 | 320 | | |
| |||
340 | 340 | | |
341 | 341 | | |
342 | 342 | | |
343 | | - | |
344 | 343 | | |
345 | | - | |
346 | | - | |
347 | | - | |
| 344 | + | |
348 | 345 | | |
349 | 346 | | |
350 | 347 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
313 | 313 | | |
314 | 314 | | |
315 | 315 | | |
316 | | - | |
317 | | - | |
318 | | - | |
| 316 | + | |
319 | 317 | | |
| 318 | + | |
| 319 | + | |
320 | 320 | | |
321 | 321 | | |
322 | 322 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
479 | 479 | | |
480 | 480 | | |
481 | 481 | | |
| 482 | + | |
| 483 | + | |
| 484 | + | |
| 485 | + | |
| 486 | + | |
| 487 | + | |
| 488 | + | |
| 489 | + | |
| 490 | + | |
| 491 | + | |
| 492 | + | |
| 493 | + | |
| 494 | + | |
| 495 | + | |
| 496 | + | |
| 497 | + | |
| 498 | + | |
| 499 | + | |
| 500 | + | |
| 501 | + | |
| 502 | + | |
| 503 | + | |
| 504 | + | |
| 505 | + | |
| 506 | + | |
| 507 | + | |
| 508 | + | |
| 509 | + | |
| 510 | + | |
482 | 511 | | |
483 | 512 | | |
484 | 513 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
66 | 66 | | |
67 | 67 | | |
68 | 68 | | |
| 69 | + | |
| 70 | + | |
69 | 71 | | |
70 | 72 | | |
71 | 73 | | |
| |||
111 | 113 | | |
112 | 114 | | |
113 | 115 | | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
114 | 120 | | |
115 | 121 | | |
116 | 122 | | |
| |||
228 | 234 | | |
229 | 235 | | |
230 | 236 | | |
231 | | - | |
| 237 | + | |
| 238 | + | |
232 | 239 | | |
233 | 240 | | |
234 | 241 | | |
| |||
240 | 247 | | |
241 | 248 | | |
242 | 249 | | |
243 | | - | |
| 250 | + | |
244 | 251 | | |
245 | | - | |
246 | | - | |
247 | | - | |
248 | | - | |
249 | | - | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
250 | 255 | | |
251 | 256 | | |
252 | 257 | | |
| |||
390 | 395 | | |
391 | 396 | | |
392 | 397 | | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
| 434 | + | |
| 435 | + | |
| 436 | + | |
| 437 | + | |
| 438 | + | |
| 439 | + | |
| 440 | + | |
| 441 | + | |
| 442 | + | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
| 455 | + | |
| 456 | + | |
| 457 | + | |
| 458 | + | |
| 459 | + | |
| 460 | + | |
| 461 | + | |
| 462 | + | |
| 463 | + | |
| 464 | + | |
| 465 | + | |
| 466 | + | |
| 467 | + | |
| 468 | + | |
| 469 | + | |
| 470 | + | |
| 471 | + | |
| 472 | + | |
| 473 | + | |
| 474 | + | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
393 | 478 | | |
394 | 479 | | |
395 | 480 | | |
| |||
480 | 565 | | |
481 | 566 | | |
482 | 567 | | |
483 | | - | |
484 | | - | |
485 | 568 | | |
486 | 569 | | |
487 | 570 | | |
| |||
490 | 573 | | |
491 | 574 | | |
492 | 575 | | |
| 576 | + | |
| 577 | + | |
493 | 578 | | |
494 | 579 | | |
495 | 580 | | |
| |||
525 | 610 | | |
526 | 611 | | |
527 | 612 | | |
528 | | - | |
529 | | - | |
530 | | - | |
531 | | - | |
532 | | - | |
533 | | - | |
534 | | - | |
535 | | - | |
536 | | - | |
537 | | - | |
538 | | - | |
539 | 613 | | |
540 | 614 | | |
541 | 615 | | |
| |||
546 | 620 | | |
547 | 621 | | |
548 | 622 | | |
549 | | - | |
| 623 | + | |
550 | 624 | | |
551 | 625 | | |
552 | 626 | | |
| |||
561 | 635 | | |
562 | 636 | | |
563 | 637 | | |
564 | | - | |
| 638 | + | |
565 | 639 | | |
566 | 640 | | |
567 | 641 | | |
| |||
604 | 678 | | |
605 | 679 | | |
606 | 680 | | |
607 | | - | |
608 | | - | |
609 | 681 | | |
610 | 682 | | |
611 | 683 | | |
612 | 684 | | |
613 | 685 | | |
614 | 686 | | |
| 687 | + | |
| 688 | + | |
615 | 689 | | |
616 | 690 | | |
617 | 691 | | |
| |||
649 | 723 | | |
650 | 724 | | |
651 | 725 | | |
652 | | - | |
653 | | - | |
654 | | - | |
655 | | - | |
656 | | - | |
657 | | - | |
658 | | - | |
659 | | - | |
660 | | - | |
661 | | - | |
662 | | - | |
663 | | - | |
664 | 726 | | |
665 | 727 | | |
666 | 728 | | |
| |||
671 | 733 | | |
672 | 734 | | |
673 | 735 | | |
674 | | - | |
| 736 | + | |
675 | 737 | | |
676 | 738 | | |
677 | 739 | | |
| |||
690 | 752 | | |
691 | 753 | | |
692 | 754 | | |
693 | | - | |
| 755 | + | |
694 | 756 | | |
695 | 757 | | |
696 | 758 | | |
| |||
788 | 850 | | |
789 | 851 | | |
790 | 852 | | |
791 | | - | |
792 | | - | |
| 853 | + | |
| 854 | + | |
| 855 | + | |
793 | 856 | | |
794 | 857 | | |
795 | 858 | | |
796 | 859 | | |
797 | 860 | | |
798 | | - | |
| 861 | + | |
799 | 862 | | |
800 | | - | |
| 863 | + | |
| 864 | + | |
801 | 865 | | |
802 | 866 | | |
803 | 867 | | |
| |||
821 | 885 | | |
822 | 886 | | |
823 | 887 | | |
824 | | - | |
825 | | - | |
826 | | - | |
| 888 | + | |
| 889 | + | |
| 890 | + | |
827 | 891 | | |
828 | 892 | | |
829 | 893 | | |
830 | | - | |
831 | 894 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
69 | 69 | | |
70 | 70 | | |
71 | 71 | | |
72 | | - | |
| 72 | + | |
73 | 73 | | |
74 | 74 | | |
75 | 75 | | |
| |||
0 commit comments