The inference checkpoint (using the pretrained MTLSD checkpoint) is not compatible with PyTorch. #26

@yumang1cv

Recently, while running the inference step in this Colab-based tutorial (using the pre-trained MTLSD checkpoint), I encountered an error with the following code:

checkpoint = 'model_checkpoint_50000'  
raw_file = 'testing_data.zarr'  
raw_dataset = 'raw/0'  

raw, pred_lsds, pred_affs = predict(checkpoint, raw_file, raw_dataset)  

The error I received was:

Traceback (most recent call last):  
  File "/usr/local/lib/python3.11/dist-packages/gunpowder/nodes/batch_provider.py", line 193, in request_batch  
    batch = self.provide(upstream_request)  
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^  
  File "/usr/local/lib/python3.11/dist-packages/gunpowder/nodes/batch_filter.py", line 148, in provide  
    dependencies = self.prepare(request)  
                   ^^^^^^^^^^^^^^^^^^^^^  
  File "/usr/local/lib/python3.11/dist-packages/gunpowder/nodes/generic_predict.py", line 116, in prepare  
    self.start()  
  File "/usr/local/lib/python3.11/dist-packages/gunpowder/torch/nodes/predict.py", line 105, in start  
    checkpoint = torch.load(self.checkpoint, map_location=self.device)  
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^  
  File "/usr/local/lib/python3.11/dist-packages/torch/serialization.py", line 1384, in load  
    return _legacy_load(  
           ^^^^^^^^^^^^^  
  File "/usr/local/lib/python3.11/dist-packages/torch/serialization.py", line 1628, in _legacy_load  
    magic_number = pickle_module.load(f, **pickle_load_args)  
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^  
_pickle.UnpicklingError: invalid load key, '<'.  

Exception in pipeline:  
ZarrSource[testing_data.zarr] -> Normalize -> Unsqueeze -> Stack -> Predict -> Scan -> Squeeze -> Squeeze  
while trying to process request  

    RAW: ROI: [0:5000, 0:5000] (5000, 5000), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False  
    PRED_LSDS: ROI: [40:4960, 40:4960] (4920, 4920), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False  
    PRED_AFFS: ROI: [40:4960, 40:4960] (4920, 4920), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False  

The download command itself completed without an HTTP error, as shown below:

--2025-02-10 09:22:56--  https://www.dropbox.com/s/r1u8pvji5lbanyq/model_checkpoint_50000  
Resolving www.dropbox.com (www.dropbox.com)... 162.125.65.18, 2620:100:6021:18::a27d:4112  
Connecting to www.dropbox.com (www.dropbox.com)|162.125.65.18|:443... connected.  
HTTP request sent, awaiting response... 200 OK  
Length: unspecified [text/html]  
Saving to: ‘model_checkpoint_50000’  

model_checkpoint_50     [  <=>               ]  72.45K   209KB/s    in 0.3s    

2025-02-10 09:22:58 (209 KB/s) - ‘model_checkpoint_50000’ saved [74187]  

After inspecting the issue, I realized that this checkpoint is likely based on TensorFlow and is therefore incompatible with PyTorch. Could you please provide a PyTorch-compatible version of the checkpoint to resolve this issue?
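A quick sanity check on the downloaded file may help narrow this down. The wget log above reports the response as `[text/html]` with a size of about 72 KB, and the unpickling error complains about a leading `'<'`, so it seems worth confirming what the file actually contains before concluding anything about the framework. The following is a minimal sketch, not part of the tutorial: the helper name `sniff_file_kind` and its return labels are hypothetical, and the byte signatures it looks for (a leading `<` for markup, `PK\x03\x04` for a zip archive such as the one newer `torch.save` versions write, `\x80` for a binary pickle stream) are heuristics only.

```python
import os


def sniff_file_kind(path, n_bytes=64):
    """Guess what a downloaded file contains from its first few bytes.

    Hypothetical helper for illustration; the labels are heuristics,
    not a definitive format detection.
    """
    with open(path, "rb") as f:
        head = f.read(n_bytes)

    # An HTML or XML page starts with '<', possibly after whitespace.
    if head.lstrip().startswith(b"<"):
        return "html-or-xml"
    # Zip archives (the default container of recent torch.save) start with 'PK\x03\x04'.
    if head.startswith(b"PK\x03\x04"):
        return "zip"
    # Binary pickle streams (protocol 2 and later) start with the PROTO opcode 0x80.
    if head.startswith(b"\x80"):
        return "pickle"
    return "unknown"


if __name__ == "__main__":
    checkpoint = "model_checkpoint_50000"  # file name taken from the report above
    if os.path.exists(checkpoint):
        size = os.path.getsize(checkpoint)
        print(f"{checkpoint}: {size} bytes, looks like {sniff_file_kind(checkpoint)}")
    else:
        print(f"{checkpoint} not found in the current directory")
```

If this reports `html-or-xml`, the saved file would be a web page returned by the share link rather than the checkpoint itself, in which case the fix is likely on the download side (for example, a direct-download form of the URL, which is an assumption about the hosting and not something confirmed here) rather than a framework mismatch.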
