Recently, while running the inference step in this Colab-based tutorial (using the pre-trained MTLSD checkpoint), I encountered an error with the following code:
checkpoint = 'model_checkpoint_50000'
raw_file = 'testing_data.zarr'
raw_dataset = 'raw/0'
raw, pred_lsds, pred_affs = predict(checkpoint, raw_file, raw_dataset)

The error I received was:
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/gunpowder/nodes/batch_provider.py", line 193, in request_batch
    batch = self.provide(upstream_request)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/gunpowder/nodes/batch_filter.py", line 148, in provide
    dependencies = self.prepare(request)
                   ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/gunpowder/nodes/generic_predict.py", line 116, in prepare
    self.start()
  File "/usr/local/lib/python3.11/dist-packages/gunpowder/torch/nodes/predict.py", line 105, in start
    checkpoint = torch.load(self.checkpoint, map_location=self.device)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/serialization.py", line 1384, in load
    return _legacy_load(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/serialization.py", line 1628, in _legacy_load
    magic_number = pickle_module.load(f, **pickle_load_args)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_pickle.UnpicklingError: invalid load key, '<'.
Exception in pipeline:
ZarrSource[testing_data.zarr] -> Normalize -> Unsqueeze -> Stack -> Predict -> Scan -> Squeeze -> Squeeze
while trying to process request
RAW: ROI: [0:5000, 0:5000] (5000, 5000), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False
PRED_LSDS: ROI: [40:4960, 40:4960] (4920, 4920), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False
PRED_AFFS: ROI: [40:4960, 40:4960] (4920, 4920), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False
I was able to download the checkpoint successfully, as shown below:
--2025-02-10 09:22:56-- https://www.dropbox.com/s/r1u8pvji5lbanyq/model_checkpoint_50000
Resolving www.dropbox.com (www.dropbox.com)... 162.125.65.18, 2620:100:6021:18::a27d:4112
Connecting to www.dropbox.com (www.dropbox.com)|162.125.65.18|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/html]
Saving to: ‘model_checkpoint_50000’
model_checkpoint_50 [ <=> ] 72.45K 209KB/s in 0.3s
2025-02-10 09:22:58 (209 KB/s) - ‘model_checkpoint_50000’ saved [74187]
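As a side note, the "invalid load key, '<'" error above means pickle read '<' as the very first byte of the file, so looking at the start of the downloaded file shows what torch.load is actually being handed. This is only a sketch; the file name is the one used above:

# Print the first bytes of the downloaded checkpoint. A legacy torch.save file
# starts with a pickle protocol marker and a zip-based one with b'PK';
# a leading b'<' means the file is not a pickle at all.
with open('model_checkpoint_50000', 'rb') as f:
    print(f.read(64))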
After inspecting the issue, I suspect this checkpoint was saved with TensorFlow and is therefore incompatible with PyTorch. Could you please provide a PyTorch-compatible version of the checkpoint to resolve this issue?
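For reference, a standalone check like the following (just a sketch, independent of the gunpowder pipeline; the file name is the one from the tutorial) would make it easy to verify that a replacement checkpoint loads in PyTorch:

import torch

# Attempt to load the checkpoint outside the pipeline. With a PyTorch-format
# file this succeeds and the contents can be inspected; with the current file
# it raises the same UnpicklingError shown above.
state = torch.load('model_checkpoint_50000', map_location='cpu')
print(type(state))
if isinstance(state, dict):
    print(list(state.keys()))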