🐛[BUG]: Modules defined in main script cannot be loaded #1169

@nbren12

Version

1.2.0

On which installation method(s) does this occur?

No response

Describe the issue

I defined a module in a training script and passed a name to its metadata. The saved model cannot be loaded from another script, even when that script imports the train module.

I am unsure which information is used to register a new module.

Is it the name field of physicsnemo.ModelMetaData? That is what I expected.

Or is it the class name and module? If so, does refactoring the class into a different module make existing checkpoints incompatible? How should this be handled?

Minimum reproducible example

# train.py
from dataclasses import dataclass
import physicsnemo
import torch.nn as nn
from physicsnemo.launch.utils import load_checkpoint, save_checkpoint


# modified from https://docs.nvidia.com/physicsnemo/latest/physicsnemo/api/models/modules.html#how-to-write-your-own-physicsnemo-model


class UNetExample(physicsnemo.Module):
    def __init__(self, in_channels=1, out_channels=1, outc=None):
        super().__init__(meta=physicsnemo.ModelMetaData(name="UnetExample"))

        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)

        self.dec1 = self.upconv_block(128, 64)
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

    def upconv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x = self.dec1(x2)
        return self.final(x)



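if __name__ == "__main__":
    # normal save works
    unet = UNetExample()
    # Assumed completion: the snippet as posted stops above; out.mdlus is the
    # file passed to load_module.py in the log below.
    unet.save("out.mdlus")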


Loading script

$ cat load_module.py
import physicsnemo
import train # imported so that model can register itself
import sys

physicsnemo.Module.from_checkpoint(sys.argv[1])

Relevant log output

$ python3 train.py
/usr/local/lib/python3.12/dist-packages/physicsnemo/utils/filesystem.py:76: SyntaxWarning: invalid escape sequence '\w'
  pattern = re.compile(f"{suffix}[\w-]+(/[\w-]+)?/[\w-]+@[A-Za-z0-9.]+/[\w/](.*)")
/usr/local/lib/python3.12/dist-packages/physicsnemo/launch/logging/launch.py:321: SyntaxWarning: invalid escape sequence '\.'
  key = re.sub("[^a-zA-Z0-9\.\-\s\/\_]+", "", key)
$ python3 load_module.py out.mdlus
/usr/local/lib/python3.12/dist-packages/physicsnemo/utils/filesystem.py:76: SyntaxWarning: invalid escape sequence '\w'
  pattern = re.compile(f"{suffix}[\w-]+(/[\w-]+)?/[\w-]+@[A-Za-z0-9.]+/[\w/](.*)")
/usr/local/lib/python3.12/dist-packages/physicsnemo/launch/logging/launch.py:321: SyntaxWarning: invalid escape sequence '\.'
  key = re.sub("[^a-zA-Z0-9\.\-\s\/\_]+", "", key)
Model {'__name__': 'UNetExample', '__module__': '__main__', '__args__': {'in_channels': 1, 'out_channels': 1, 'outc': None}}
Traceback (most recent call last):
  File "/lustre/fs1/portfolios/coreai/projects/coreai_climate_earth2/nbrenowitz/repos/edm-chaos/load_module.py", line 5, in <module>
    physicsnemo.Module.from_checkpoint(sys.argv[1])
  File "/usr/local/lib/python3.12/dist-packages/physicsnemo/models/module.py", line 517, in from_checkpoint
    model = Module.instantiate(args)
            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/physicsnemo/models/module.py", line 278, in instantiate
    return _cls(**arg_dict["__args__"])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/physicsnemo/models/module.py", line 74, in __new__
    bound_args = sig.bind_partial(
                 ^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/inspect.py", line 3249, in bind_partial
    return self._bind(args, kwargs, partial=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/inspect.py", line 3231, in _bind
    raise TypeError(
TypeError: got an unexpected keyword argument 'in_channels'
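
The `Model {...}` line in the log records `'__module__': '__main__'`. The sketch below reproduces only the plain Python behavior behind that value: a class defined in a file that is run as a script gets `__module__ == "__main__"`, while the same class reached through `import train` gets `__module__ == "train"`. It uses a temporary stand-in for train.py and no physicsnemo code, so it makes no claim about how `from_checkpoint` uses these fields or about what ultimately raises the `TypeError`.

```python
import importlib.util
import runpy
import tempfile
from pathlib import Path

# Stand-in source for train.py; the class body is irrelevant here.
SOURCE = "class UNetExample:\n    pass\n"


def module_names():
    """Return the class's __module__ when run as a script and when imported."""
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "train.py"
        path.write_text(SOURCE)

        # Like `python3 train.py`: the file executes under the name "__main__".
        as_script = runpy.run_path(str(path), run_name="__main__")["UNetExample"]

        # Like `import train`: the file executes under the name "train".
        spec = importlib.util.spec_from_file_location("train", path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        as_import = module.UNetExample

        return as_script.__module__, as_import.__module__


script_name, import_name = module_names()
print(script_name, import_name)  # prints: __main__ train
```

So the checkpoint written by `python3 train.py` names the class as `__main__.UNetExample`, whereas in load_module.py the imported class is `train.UNetExample`.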

Environment details

Labels

! - Release: PRs or Issues relating to a release
Documentation: Improvements or additions to documentation
bug: Something isn't working
