Status: Open
Labels: ! - Release, Documentation, bug
Description
Version
1.2.0
On which installation method(s) does this occur?
No response
Describe the issue
I defined a module in a training script and gave it a name through its metadata (`physicsnemo.ModelMetaData(name=...)`). A checkpoint of this model cannot be loaded from another script, even when that script imports the training module.
I am confused about which information is used to register a new module.
Is it the `name` field of `physicsnemo.ModelMetaData`? That is what I expected.
Is it the class name and module? If so, does moving the module to a different location (for example, during a refactor) make existing checkpoints incompatible? How should this be handled?
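The load log in this issue records the checkpoint's class as `{'__name__': 'UNetExample', '__module__': '__main__', ...}`: the Python class name (`UNetExample`), not the metadata name (`UnetExample`), together with the module the class was defined in. A minimal pure-Python sketch of such a record, assuming the checkpoint simply stores `type(model).__name__` and `type(model).__module__`; `MetaData` and `describe` are hypothetical stand-ins, not physicsnemo APIs:

```python
from dataclasses import dataclass


@dataclass
class MetaData:
    """Hypothetical stand-in for physicsnemo.ModelMetaData (only `name` is modeled)."""

    name: str


class UNetExample:
    def __init__(self, in_channels=1, out_channels=1, outc=None):
        # The metadata name deliberately differs in capitalization from the class name.
        self.meta = MetaData(name="UnetExample")
        self.init_args = {"in_channels": in_channels, "out_channels": out_channels, "outc": outc}


def describe(model):
    """Hypothetical helper: build an identity record shaped like the one in the log."""
    cls = type(model)
    return {"__name__": cls.__name__, "__module__": cls.__module__, "__args__": model.init_args}


record = describe(UNetExample())
print(record)
```

Run as `python3 train.py`, the record contains `'__module__': '__main__'`, matching the log, and `meta.name` appears nowhere in it; under this assumption the metadata name plays no part in finding the class again.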
Minimum reproducible example
```python
# train.py
from dataclasses import dataclass

import physicsnemo
import torch.nn as nn
from physicsnemo.launch.utils import load_checkpoint, save_checkpoint


# modified from https://docs.nvidia.com/physicsnemo/latest/physicsnemo/api/models/modules.html#how-to-write-your-own-physicsnemo-model
class UNetExample(physicsnemo.Module):
    def __init__(self, in_channels=1, out_channels=1, outc=None):
        super().__init__(meta=physicsnemo.ModelMetaData(name="UnetExample"))
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.dec1 = self.upconv_block(128, 64)
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

    def upconv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x = self.dec1(x2)
        return self.final(x)


if __name__ == "__main__":
    # normal save works
    unet = UNetExample()
    unet.save("out.mdlus")  # the checkpoint passed to load_module.py below
```
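Because `train.py` is run directly, the classes it defines are created in the module named `__main__`; `import train` from another script executes the same source again under the name `train`, producing a separate class object. A minimal sketch of that Python behavior, using `exec` with an explicit `__name__` to stand in for running versus importing the file (`SOURCE` and `define_under` are hypothetical names for illustration):

```python
SOURCE = """
class UNetExample:
    pass
"""


def define_under(module_name, source=SOURCE):
    """Execute `source` as if it were the module `module_name` and return its class.

    A class statement takes its __module__ from the global __name__, so the same
    source yields classes with different identities depending on that name.
    """
    namespace = {"__name__": module_name}
    exec(source, namespace)
    return namespace["UNetExample"]


as_script = define_under("__main__")  # like `python3 train.py`
as_import = define_under("train")  # like `import train` inside load_module.py

print(as_script.__module__, as_import.__module__)
print(as_script is as_import)
```

This is consistent with the log: the checkpoint saved by `python3 train.py` names `__main__.UNetExample`, while `import train` makes the class available as `train.UNetExample`, and once `load_module.py` is the running script, `__main__` refers to `load_module.py` instead.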
Loading script (`load_module.py`):

```python
import physicsnemo
import train  # imported so that the model can register itself
import sys

physicsnemo.Module.from_checkpoint(sys.argv[1])
```

Relevant log output
```
$ python3 train.py
/usr/local/lib/python3.12/dist-packages/physicsnemo/utils/filesystem.py:76: SyntaxWarning: invalid escape sequence '\w'
  pattern = re.compile(f"{suffix}[\w-]+(/[\w-]+)?/[\w-]+@[A-Za-z0-9.]+/[\w/](.*)")
/usr/local/lib/python3.12/dist-packages/physicsnemo/launch/logging/launch.py:321: SyntaxWarning: invalid escape sequence '\.'
  key = re.sub("[^a-zA-Z0-9\.\-\s\/\_]+", "", key)
$ python3 load_module.py out.mdlus
/usr/local/lib/python3.12/dist-packages/physicsnemo/utils/filesystem.py:76: SyntaxWarning: invalid escape sequence '\w'
  pattern = re.compile(f"{suffix}[\w-]+(/[\w-]+)?/[\w-]+@[A-Za-z0-9.]+/[\w/](.*)")
/usr/local/lib/python3.12/dist-packages/physicsnemo/launch/logging/launch.py:321: SyntaxWarning: invalid escape sequence '\.'
  key = re.sub("[^a-zA-Z0-9\.\-\s\/\_]+", "", key)
Model {'__name__': 'UNetExample', '__module__': '__main__', '__args__': {'in_channels': 1, 'out_channels': 1, 'outc': None}}
Traceback (most recent call last):
  File "/lustre/fs1/portfolios/coreai/projects/coreai_climate_earth2/nbrenowitz/repos/edm-chaos/load_module.py", line 5, in <module>
    physicsnemo.Module.from_checkpoint(sys.argv[1])
  File "/usr/local/lib/python3.12/dist-packages/physicsnemo/models/module.py", line 517, in from_checkpoint
    model = Module.instantiate(args)
            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/physicsnemo/models/module.py", line 278, in instantiate
    return _cls(**arg_dict["__args__"])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/physicsnemo/models/module.py", line 74, in __new__
    bound_args = sig.bind_partial(
                 ^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/inspect.py", line 3249, in bind_partial
    return self._bind(args, kwargs, partial=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/inspect.py", line 3231, in _bind
    raise TypeError(
TypeError: got an unexpected keyword argument 'in_channels'
```
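The traceback ends in `inspect.Signature.bind_partial`, called from `__new__` in `physicsnemo/models/module.py` while instantiating with the saved `__args__`: the constructor whose signature is being checked does not accept `in_channels`. One reading consistent with this (an inference from the traceback, not confirmed against the physicsnemo source) is that looking up `UNetExample` in `__main__` fails once `__main__` is `load_module.py`, and the loader falls back to a base class. A pure-Python sketch of that failure mode; `Base`, `modules_at_load_time`, `resolve`, and `bind_saved_args` are hypothetical:

```python
import inspect


class Base:
    """Hypothetical stand-in for a base class whose constructor takes no model arguments."""

    def __init__(self, meta=None):
        self.meta = meta


class UNetExample(Base):
    def __init__(self, in_channels=1, out_channels=1, outc=None):
        super().__init__(meta="UnetExample")


# The record printed in the log: saved while train.py was running as __main__.
record = {
    "__name__": "UNetExample",
    "__module__": "__main__",
    "__args__": {"in_channels": 1, "out_channels": 1, "outc": None},
}

# Hypothetical module namespaces while load_module.py runs: __main__ is now
# load_module.py, and the imported class lives in the module "train".
modules_at_load_time = {"__main__": {}, "train": {"UNetExample": UNetExample}}


def resolve(record, modules, fallback=Base):
    """Look the class up by (module, name), falling back to `fallback` if it is missing."""
    return modules.get(record["__module__"], {}).get(record["__name__"], fallback)


def bind_saved_args(cls, args):
    """Check the saved keyword arguments against the constructor signature of `cls`."""
    return inspect.signature(cls).bind_partial(**args)


cls = resolve(record, modules_at_load_time)
print(cls.__name__)
try:
    bind_saved_args(cls, record["__args__"])
    message = "bound without error"
except TypeError as err:
    message = str(err)
print(message)
```

With the record from the log, `resolve` returns `Base` and binding fails with the same `got an unexpected keyword argument 'in_channels'` message as the traceback; pointing the record at `"train"` instead resolves `UNetExample`, whose signature accepts the saved arguments.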