Skip to content
Merged
Show file tree
Hide file tree
Changes from 241 commits
Commits
Show all changes
254 commits
Select commit Hold shift + click to select a range
f753573
Add missing init files
coreyjadams Nov 3, 2025
2ef835e
Update build system and specify some deps.
coreyjadams Nov 3, 2025
1603067
Merge branch 'main' into refactor
coreyjadams Nov 3, 2025
1e8df52
Reorganize tests.
coreyjadams Nov 3, 2025
2e1195c
Update init files
coreyjadams Nov 3, 2025
a698685
Clean up neighbor tools.
coreyjadams Nov 3, 2025
258d988
Update testing
coreyjadams Nov 3, 2025
0638b97
Fix compat tests
coreyjadams Nov 3, 2025
b6327cb
Move core model tests to tests/core/
coreyjadams Nov 3, 2025
3ce049a
Add import lint config
coreyjadams Nov 3, 2025
95fa450
Relocate layers
coreyjadams Nov 3, 2025
ba6813d
Move graphcast utils into model directory
coreyjadams Nov 3, 2025
3f10463
Relocating util functionalities.
coreyjadams Nov 4, 2025
339b484
Further clean up and organize tests.
coreyjadams Nov 5, 2025
18df402
Merge branch 'NVIDIA:main' into refactor
coreyjadams Nov 5, 2025
d6946d9
utils tests are passing now
coreyjadams Nov 5, 2025
66f8d15
Cleaning up distributed tests
coreyjadams Nov 5, 2025
2ee76db
Patching tests working again in nn
coreyjadams Nov 5, 2025
33d525d
Fix sdf test
coreyjadams Nov 5, 2025
a06ad0a
Fix zenith angle tests
coreyjadams Nov 5, 2025
4c845cc
Some organization of tests. Checkpoints is moved into utils.
coreyjadams Nov 5, 2025
3bb64f4
Remove launch.utils and launch.config. Checkpointing is moved to
coreyjadams Nov 5, 2025
4aa332e
Most nn tests are passing
coreyjadams Nov 5, 2025
45686cc
Further cleanup. Getting there!
coreyjadams Nov 5, 2025
bbc54f6
Remove constants file
coreyjadams Nov 5, 2025
8453fea
Add import linting to pre-commit.
coreyjadams Nov 5, 2025
7ff2a2a
Refactor (#1208)
coreyjadams Nov 5, 2025
f850488
Merge branch 'main' into refactor
coreyjadams Nov 5, 2025
1c5f91c
Merge branch 'main' into v2.0-refactor
coreyjadams Nov 5, 2025
21343f5
Unmigrate the insolation utils (#1211)
pzharrington Nov 6, 2025
337c91e
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 7, 2025
4583c42
Move gnn layers and start to fix several model tests.
coreyjadams Nov 7, 2025
e326d4a
AFNO is now passing.
coreyjadams Nov 7, 2025
b95097d
Rnn models passing.
coreyjadams Nov 7, 2025
d8bc6f9
Fix improt
coreyjadams Nov 7, 2025
314f1b2
Healpix tests are working
coreyjadams Nov 7, 2025
9c7d287
Domino and unet working
coreyjadams Nov 7, 2025
0012209
Refactor (#1216)
coreyjadams Nov 7, 2025
32e1dce
Update activations path in dlwp tests (#1217)
pzharrington Nov 7, 2025
afa903f
Updating to address some test issues
coreyjadams Nov 10, 2025
91ceb0a
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 10, 2025
f9130a6
Merge branch 'main' into v2.0-refactor
coreyjadams Nov 10, 2025
ceb1eb8
Merge branch 'main' into refactor
coreyjadams Nov 10, 2025
0592d80
MGN tests passing again
coreyjadams Nov 10, 2025
857b3db
Most graphcast tests passing again
coreyjadams Nov 10, 2025
f89a2fb
Move nd conv layers.
coreyjadams Nov 10, 2025
409200d
update fengwu and pangu
coreyjadams Nov 10, 2025
14b51fd
Update sfno and pix2pix test
coreyjadams Nov 10, 2025
27fd304
update tests for figconvnet, swinrnn, superresnet
coreyjadams Nov 10, 2025
0d22d11
updating more models to pass
coreyjadams Nov 10, 2025
60ba0ce
Update distributed tests, now passing.
coreyjadams Nov 10, 2025
7ec2251
Domain parallel tests now passing.
coreyjadams Nov 11, 2025
d9fe7a4
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 12, 2025
af9e359
Fix active learning imports so tests pass in refactor
coreyjadams Nov 12, 2025
e3b7849
Fix some metric imports
coreyjadams Nov 12, 2025
b1f2ef9
Remove deploy package
coreyjadams Nov 12, 2025
f46ff8c
Remove unused test file
coreyjadams Nov 12, 2025
edd2224
unmigrate these files ... again?
coreyjadams Nov 12, 2025
1c769e3
Update import linter.
coreyjadams Nov 12, 2025
b9aa3dd
Refactor (#1224)
coreyjadams Nov 12, 2025
8d8255a
Merge branch 'main' into refactor
coreyjadams Nov 12, 2025
8b266b0
Cleaning up diffusion models. Not quite done yet.
coreyjadams Nov 12, 2025
8a8a05a
Merge branch 'main' into refactor
coreyjadams Nov 12, 2025
9b0d40d
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 13, 2025
ff0aacf
Restore deleted files
coreyjadams Nov 13, 2025
f11fcd7
Updating more tests.
coreyjadams Nov 13, 2025
9e32712
Further updates to tests. Datapipes almost working.
coreyjadams Nov 14, 2025
4fe41b9
Refactor (#1231)
coreyjadams Nov 14, 2025
0b78d6c
Merge branch 'NVIDIA:main' into refactor
coreyjadams Nov 17, 2025
ac1fcef
update import paths
coreyjadams Nov 17, 2025
d81ee43
Starting to clean up dependency tree.
coreyjadams Nov 18, 2025
dff27b3
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 18, 2025
8a0a3a5
Refactor (#1233)
coreyjadams Nov 18, 2025
3cb9a02
Added coding standards for model implementations as a custom context …
CharlelieLrt Nov 18, 2025
d7bcd0d
Fixing and adjusting a broad suite of tests.
coreyjadams Nov 19, 2025
d32879d
Merge branch 'NVIDIA:main' into refactor
coreyjadams Nov 19, 2025
c4ef437
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 19, 2025
b3b7786
Update test/domain_parallel/conftest.py
coreyjadams Nov 19, 2025
af41fdf
Minor fix
coreyjadams Nov 19, 2025
611a029
Refactor (#1234)
coreyjadams Nov 19, 2025
58c909c
Merge branch 'main' into v2.0-refactor
coreyjadams Nov 19, 2025
17ff6de
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 19, 2025
e83ea99
Not seeing any errors in testing ...
coreyjadams Nov 19, 2025
ec163e1
Breakdown of rules into smaller rules (#1236)
CharlelieLrt Nov 19, 2025
15a04f1
Merge branch 'NVIDIA:main' into refactor
coreyjadams Nov 20, 2025
42e4b40
Refactor (#1240)
coreyjadams Nov 20, 2025
ff8ddac
Merge branch 'main' into v2.0-refactor
coreyjadams Nov 20, 2025
51c0ccb
Merge branch 'main' into refactor
coreyjadams Nov 24, 2025
e16f9f2
Refactor (#1247)
coreyjadams Nov 24, 2025
60ccc72
Enable import linting on internal imports.
coreyjadams Nov 24, 2025
9b62b7d
Remove ensure_available function, it's confusing
coreyjadams Nov 24, 2025
f05150f
Add logging imports to utils, and fix imports in examples.
coreyjadams Nov 24, 2025
64d731f
Update imports in minimal examples
coreyjadams Nov 24, 2025
725ecfe
Update structural mechanics examples
coreyjadams Nov 24, 2025
d8e5f05
Update import paths: reservoir_sim
coreyjadams Nov 24, 2025
666be4b
Update import paths: additive manufacturing
coreyjadams Nov 24, 2025
19b8afd
Update import paths: topodiff
coreyjadams Nov 24, 2025
824c76a
Update import paths: weather part 1
coreyjadams Nov 24, 2025
641c110
Update import paths: weather part 2
coreyjadams Nov 24, 2025
2e056db
Update import paths: molecular dynamics
coreyjadams Nov 24, 2025
6a9f6e6
Update import paths: geophysics
coreyjadams Nov 24, 2025
b874e4e
Update import paths: cfd + external_aero 1
coreyjadams Nov 24, 2025
23f2955
Update import paths: cfd + external_aero 2
coreyjadams Nov 24, 2025
581c79a
Remove more DGL examples
coreyjadams Nov 24, 2025
6d780d7
Remove more DGL examples
coreyjadams Nov 24, 2025
7763d96
cfd examples 3
coreyjadams Nov 24, 2025
53fa1cb
Last batch of example import fixes!
coreyjadams Nov 24, 2025
1cd3ada
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 24, 2025
5fdcf0f
Enforce and protect external deps in utils.
coreyjadams Nov 25, 2025
b5842e3
Remove DGL. :party:
coreyjadams Nov 25, 2025
da742e7
Don't force models yet
coreyjadams Nov 25, 2025
6c872a0
Refactor (#1249)
coreyjadams Nov 25, 2025
363126a
Automated model registry (#1252)
CharlelieLrt Nov 26, 2025
76a29ef
Metadata name deprecation (#1257)
CharlelieLrt Nov 26, 2025
942c375
Merge main into local refactor
coreyjadams Dec 1, 2025
8d8939d
Refactor (#1258)
coreyjadams Dec 1, 2025
cbc2dd3
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 1, 2025
170efa7
Merge branch 'v2.0-refactor' into refactor
coreyjadams Dec 1, 2025
8898450
Remove IPDB
coreyjadams Dec 1, 2025
8aa8dd9
Few more dep fixes.
coreyjadams Dec 1, 2025
70d9135
Merge branch 'main' into refactor
coreyjadams Dec 2, 2025
ec69852
Refactor (#1261)
coreyjadams Dec 2, 2025
17788b0
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 2, 2025
3c03b08
Add external import coding standards.
coreyjadams Dec 2, 2025
a842398
Update external import standards.
coreyjadams Dec 3, 2025
dae0942
Ensure vtk functions are protected.
coreyjadams Dec 3, 2025
042f7ea
Protect pyvista import
coreyjadams Dec 3, 2025
5bb0e6f
Closing more import gaps
coreyjadams Dec 3, 2025
d35d5c7
Remove DGL from meshgraphkan
coreyjadams Dec 3, 2025
12b98d8
All models now comply with external import linting.
coreyjadams Dec 3, 2025
a879e8d
Remove DGL datapipes
coreyjadams Dec 3, 2025
b200b50
cae datapipes in compliance
coreyjadams Dec 3, 2025
cb1766c
Update pyproject.toml
coreyjadams Dec 3, 2025
d339e1f
Add version numbers to deps
coreyjadams Dec 3, 2025
aad176c
Refactor (#1261)
coreyjadams Dec 3, 2025
6c9cebd
Merge branch 'refactor' into v2.0-refactor
coreyjadams Dec 3, 2025
7422e4c
fix import error from wandb
coreyjadams Dec 3, 2025
75490ea
remove instance check
coreyjadams Dec 3, 2025
ddf6ea9
Initial restructure
CharlelieLrt Dec 5, 2025
8e634f9
Completed restructure of diffusion package
CharlelieLrt Dec 5, 2025
1f66eb6
UV <---> Pip must stay in sync. (#1264)
coreyjadams Dec 8, 2025
ab46322
Fix broken imports
coreyjadams Dec 8, 2025
c8c4da6
Fix README links in transolver and domino examples (#1259)
dran-dev Dec 4, 2025
9132858
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 8, 2025
770589b
Add xarray, timm to core deps
coreyjadams Dec 8, 2025
e07fbd2
update import
coreyjadams Dec 9, 2025
6f92470
Somehow, a number of import protections got broken
coreyjadams Dec 9, 2025
de56395
Automatically select CPU or CPU+CUDA instead of decorating every test.
coreyjadams Dec 10, 2025
281e90c
ensure te installed for serialization test
coreyjadams Dec 10, 2025
289c11d
All CPU tests are passing
coreyjadams Dec 10, 2025
d07eab2
Remove DGL/PyG equivalency tests (#1273)
Alexey-Kamenev Dec 10, 2025
c6d6525
Install ci (#1274)
coreyjadams Dec 12, 2025
ea5ab3a
Remove TensorFlow dependency in Vortex Shedding and Lagrangian MGN ex…
Alexey-Kamenev Dec 12, 2025
18f5872
Change registry behavior and list all models as entry points (#1278)
CharlelieLrt Dec 12, 2025
1c2fa2f
Renamed LearnedSimulator into VGFNLearnedSimulator
CharlelieLrt Dec 13, 2025
c2464d3
Fix tests + improve docs for new register arg in from_torch
CharlelieLrt Dec 13, 2025
6305bb8
Remove physicsnemo.model.Module remaining items
coreyjadams Dec 15, 2025
d1ca859
Remove incorrect meta import
coreyjadams Dec 15, 2025
7753a55
Remove incorrect comment
coreyjadams Dec 15, 2025
8c7c08a
Fix linting errors
coreyjadams Dec 16, 2025
c4b71ea
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 16, 2025
8e55518
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 16, 2025
28d1871
Fixing some linting errors
coreyjadams Dec 16, 2025
35235fe
More linter errors
coreyjadams Dec 16, 2025
7f10726
One more.
coreyjadams Dec 16, 2025
b770b23
Update knn tests
coreyjadams Dec 16, 2025
5667675
Purge pylib cugraphops
coreyjadams Dec 16, 2025
ec6e35b
Remove more cugraphops paths.
coreyjadams Dec 16, 2025
e028a59
Trying to close some CI errors.
coreyjadams Dec 16, 2025
d83168c
Fixing more CI issues
coreyjadams Dec 16, 2025
5bf2fb8
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 16, 2025
b371bba
Fix MGN tests (#1281)
Alexey-Kamenev Dec 16, 2025
b860fcd
Fix apex issues on CPU with a diffusion-specific device fixture.
coreyjadams Dec 16, 2025
bfb9d43
Fixing shard tensor import; adjusting pytorch geometric import point …
coreyjadams Dec 16, 2025
b577342
Fixing more imports.
coreyjadams Dec 16, 2025
85c56e0
fix one or two more
coreyjadams Dec 16, 2025
cd0fb4c
Merge branch 'v2.0-refactor' into restructure-diffusion-subpackage
CharlelieLrt Dec 16, 2025
496b17e
Fix MGK, HMGN tests (#1282)
Alexey-Kamenev Dec 17, 2025
f390748
Fix import error
coreyjadams Dec 17, 2025
6a9f5cb
Remove cugraphops
coreyjadams Dec 17, 2025
76c6bd5
Fix many tests
coreyjadams Dec 17, 2025
6aa5c22
Add migration guide early draft. Update external imports.
coreyjadams Dec 17, 2025
020d928
Attempting to fix the last failing tests.
coreyjadams Dec 17, 2025
23ae40e
Add pre-commit action. (#1286)
coreyjadams Dec 17, 2025
d5fc130
Tweak the CI install and testing of imports / docstrings
coreyjadams Dec 17, 2025
8441672
Wow, the tests were not tied to ANY timezone. It only passes in UTC....
coreyjadams Dec 17, 2025
05f839b
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 17, 2025
3a74e58
fix all but 2 docstring tests
coreyjadams Dec 17, 2025
25f8b56
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 18, 2025
dc03aab
Resolve circular import + fix linting errors.
coreyjadams Dec 17, 2025
bdded65
Fixed broken Group Norm
CharlelieLrt Dec 18, 2025
f431779
Merge branch 'v2.0-refactor' into restructure-diffusion-subpackage
CharlelieLrt Dec 18, 2025
ea2314e
Merge branch 'main' into restructure-diffusion-subpackage
CharlelieLrt Dec 19, 2025
ea3c105
Added diffusion.generate
CharlelieLrt Dec 19, 2025
5bb7a49
Added future feature and deprecation warnings for diffusion module
CharlelieLrt Dec 19, 2025
7388bd3
Defined import-linter contracts for physicsnemo.diffusion
CharlelieLrt Dec 19, 2025
2bea331
Updated PR template with missing item
CharlelieLrt Dec 19, 2025
b3518d0
Added missing diffusion.generate
CharlelieLrt Dec 19, 2025
dae4fbc
Fixed a few remaining paths physicsnemo.models.diffusion that does no…
CharlelieLrt Dec 19, 2025
1969383
CI tests fixes
CharlelieLrt Dec 19, 2025
b452e5b
mmiranda nvidia style guide Updates diffusion.rst
megnvidia Dec 19, 2025
55fee4b
mmiranda smol style guide Updates physicsnemo.utils.rst
megnvidia Dec 19, 2025
cdd92cf
Fixed checklist in PR template
CharlelieLrt Dec 19, 2025
071a7e3
Deleted comment in .importlinter
CharlelieLrt Dec 19, 2025
707e0d2
Fixed references in diffusion.rst
CharlelieLrt Dec 19, 2025
0bd74ef
Merge branch 'restructure-diffusion-subpackage' of https://github.com…
CharlelieLrt Dec 19, 2025
4aa1f7b
Merge branch 'main' into restructure-diffusion-subpackage
CharlelieLrt Jan 5, 2026
fca12d1
Fix checkpoint loading with Module subclass when known
CharlelieLrt Jan 6, 2026
8f38564
Deleted physicsnemo/compat
CharlelieLrt Jan 6, 2026
08e89f2
Deleted useless comments in flow_reconstruction_diffusion example
CharlelieLrt Jan 6, 2026
df3ad90
Renamed Attantion into UNetAttention
CharlelieLrt Jan 6, 2026
205aa97
Merge branch 'main' into restructure-diffusion-subpackage
CharlelieLrt Jan 6, 2026
6ff908e
Implemented BasePreconditioner
CharlelieLrt Jan 7, 2026
f08a686
Improvements to BaseConditioner docs
CharlelieLrt Jan 7, 2026
dc733a9
Implemented new preconditioners based on BasePerconditioner
CharlelieLrt Jan 7, 2026
fa02d48
Migrated legacy preconditioners to reuse new preconditioners
CharlelieLrt Jan 7, 2026
f7b8494
Initial implementation of tests for preconditioners
CharlelieLrt Jan 8, 2026
047ca44
Added reference data for non-regression CI tests of preconditioners
CharlelieLrt Jan 8, 2026
2866c10
Improvements to preconditioners CI tests
CharlelieLrt Jan 8, 2026
5f9a309
Adedd a few details in BasePreconditioner doctrsing
CharlelieLrt Jan 8, 2026
5af0aff
Merge branch 'main' into diffusion-preconditioners-refactor
CharlelieLrt Jan 8, 2026
17ee57e
Updated CHANGELOG.md
CharlelieLrt Jan 8, 2026
1e2d2a7
Improved documentation of signature requirement in BasePreconditioner
CharlelieLrt Jan 8, 2026
ffcb026
Renamed BasePreconditioner into BaseAffinePreconditioner
CharlelieLrt Jan 9, 2026
5d5c66a
Added DiffusionModel protocol to specify diffusion models signature
CharlelieLrt Jan 9, 2026
993db63
Changed condition argument to TensorDict instead of Dict of tensors
CharlelieLrt Jan 9, 2026
22dac49
Moved all preconditioners scalar attributes to pytorch buffers instea…
CharlelieLrt Jan 9, 2026
9c5f53b
Improvements to make precondtioners tests more robust on GPU
CharlelieLrt Jan 10, 2026
2a0f125
Initial implementation of diffusion sampler
CharlelieLrt Jan 10, 2026
a40e449
Some updates to samplers and solvers
CharlelieLrt Jan 22, 2026
91b27fa
Some progress
CharlelieLrt Jan 26, 2026
752561e
Mostly completed implementation of sampling utilities
CharlelieLrt Jan 30, 2026
d6e0343
Added tEDM and VP noise schedulers
CharlelieLrt Jan 30, 2026
e7d095c
Merge branch 'main' into diffusion-samplers-and-guidance
CharlelieLrt Jan 31, 2026
1f59c4c
Changed str to Literal
CharlelieLrt Feb 3, 2026
82565b6
Replaced scale s(t) with alpha in stochastic solvers
CharlelieLrt Feb 4, 2026
0c50b19
Removed inheritance in student-t EDM noise scheduler
CharlelieLrt Feb 6, 2026
e9f9586
Addressed PR comments
CharlelieLrt Feb 14, 2026
ac28793
Refactored protocols Denoiser and Predictor
CharlelieLrt Feb 14, 2026
58aac76
Refactored get_denoiser to use keyword arguments for the input predictor
CharlelieLrt Feb 14, 2026
556b2c8
Merge branch 'main' into diffusion-samplers-and-guidance
CharlelieLrt Feb 14, 2026
c786692
Fix license header
CharlelieLrt Feb 17, 2026
79db00f
Fixed docstring example in samplers.py
CharlelieLrt Feb 17, 2026
365c7d3
Revert "Fixed docstring example in samplers.py"
CharlelieLrt Feb 17, 2026
3acfafa
Fixed docstring example in samplers.py
CharlelieLrt Feb 17, 2026
386f968
Merge branch 'main' into diffusion-samplers-and-guidance
CharlelieLrt Feb 17, 2026
3cbc736
Changed alpha_fn to diffusion_fn in solvers.py
CharlelieLrt Feb 18, 2026
947c3d4
Fixed code-blocks and missing jaxtyping
CharlelieLrt Feb 18, 2026
94e956b
Temporarily omit physicsnemo.diffusion from coverage
CharlelieLrt Feb 18, 2026
2b6b6cb
Merge branch 'main' into diffusion-samplers-and-guidance
CharlelieLrt Feb 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion physicsnemo/diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import DiffusionModel # noqa: F401
from .base import DiffusionDenoiser, DiffusionModel # noqa: F401
97 changes: 97 additions & 0 deletions physicsnemo/diffusion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,100 @@ def __call__(
Model output with the same shape as ``x``.
"""
...


@runtime_checkable
class DiffusionDenoiser(Protocol):
r"""
Protocol defining a denoiser interface for diffusion model inference.

This is the minimal interface required for sampling from a diffusion model,
and any object that implements this interface can be used as a denoiser in
sampling utilities.

A denoiser is a callable that takes a noisy state ``x`` and a noise level
(or diffusion time) ``t``, and typically returns a denoising term. A
denoising could be the right hand side for
ODE/SDE-based sampling, directly the denoised latent state for discrete
Markov chain-based sampling, etc. This interface is generic and it does not
make any assumption about the nature of the denoising term. It is expected
to be used in conjunction with a compatible
:class:`~physicsnemo.diffusion.samplers.solvers.Solver` to perform the
actual sampling.

This protocol is used during inference. For training,
which often requires additional inputs like conditioning, use the more
general :class:`DiffusionModel` protocol instead.

A :class:`DiffusionDenoiser` can be obtained from a
:class:`DiffusionModel` by partially applying the ``condition`` and
any other keyword arguments. For example:

.. code-block:: python

from functools import partial
from tensordict import TensorDict

class MyDiffusionModel:
def __call__(self, x, t, condition):
# Model forward pass using x, t, and condition
return x * 0.9 + condition["y"]

model = MyDiffusionModel()
my_condition = TensorDict({"y": torch.randn(batch_size, 10)})

# Create denoiser by partially applying the condition
denoiser = partial(model, condition=my_condition)

# Now denoiser(x, t) implements the DiffusionDenoiser interface.

See Also
--------
:func:`~physicsnemo.diffusion.samplers.sample` : The sampling function
that uses this denoiser interface.
:class:`DiffusionModel` : The full diffusion model interface with
conditioning support.

Examples
--------
>>> import torch
>>> from physicsnemo.diffusion import DiffusionDenoiser
>>>
>>> class SimpleDenoiser:
... def __call__(self, x, t):
... # A trivial denoiser that returns the input unchanged
... return x
...
>>> denoiser = SimpleDenoiser()
>>> isinstance(denoiser, DiffusionDenoiser)
True
"""

def __call__(
self,
x: Float[torch.Tensor, "B *dims"], # noqa: F821
t: Float[torch.Tensor, "B "], # noqa: F821
) -> Float[torch.Tensor, "B *dims"]: # noqa: F821
r"""
Function to produce a denoising output at the given noise level.

Parameters
----------
x : torch.Tensor
Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
batch size and :math:`*` denotes any number of additional
dimensions (e.g., channels and spatial dimensions).
t : torch.Tensor
Batched diffusion time tensor of shape :math:`(B,)`.
All batch elements in the latent state ``x`` typically share the
same diffusion time values, but ``t`` is still required to be a
batched tensor.

Returns
-------
torch.Tensor
Denoising output. Can be the right hand side for ODE/SDE-based
sampling, the denoised latent state for discrete Markov chain-based
sampling, etc.
"""
...
17 changes: 8 additions & 9 deletions physicsnemo/diffusion/noise_schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

from physicsnemo.core.warnings import FutureFeatureWarning

warnings.warn(
"The 'physicsnemo.diffusion.noise_schedulers' module is a placeholder for "
"future functionality that will be implemented in an upcoming release.",
FutureFeatureWarning,
stacklevel=2,
from .noise_schedulers import ( # noqa: F401
EDMNoiseScheduler,
IDDPMNoiseScheduler,
LinearGaussianNoiseScheduler,
NoiseScheduler,
StudentTEDMNoiseScheduler,
VENoiseScheduler,
VPNoiseScheduler,
)
Loading
Loading