Skip to content

Commit 317a854

Browse files
deploy changes
1 parent b4e0e31 commit 317a854

3 files changed

Lines changed: 19 additions & 1 deletion

File tree

asparagus/functional/loading.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import nibabel as nib
2+
import numpy as np
3+
import torch
4+
5+
6+
def load_image_file(file: str) -> torch.Tensor:
7+
if file.endswith(".pt"):
8+
return torch.load(file)
9+
elif file.endswith(".nii.gz") or file.endswith(".nii"):
10+
nii = nib.load(file)
11+
data = nii.get_fdata(dtype=np.float32)
12+
tensor = torch.from_numpy(data)
13+
return tensor.unsqueeze(0) # (H,W,D) -> (1,H,W,D) to match .pt channel convention
14+
else:
15+
raise ValueError(f"Unsupported file format: {file}. Expected .pt, .nii, or .nii.gz")

asparagus/modules/datasets/PretrainDataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
import torchvision
3+
from asparagus.functional.loading import load_image_file
34
from torch.utils.data import Dataset
45
from typing import Optional
56

@@ -20,7 +21,7 @@ def __len__(self):
2021

2122
def __getitem__(self, idx):
2223
file = self.files[idx]
23-
data = torch.load(file)
24+
data = load_image_file(file)
2425
data_dict = {"file_path": file, "image": data, "transforms_applied": {}}
2526
data_dict = self._transform(data_dict) # CPU transforms only here
2627

asparagus/modules/networks/unet.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def unet_tiny(
8484
output_channels: int = 1,
8585
dimensions: str = "3D",
8686
deep_supervision: bool = False,
87+
use_skip_connections: bool = True,
8788
):
8889
return UNet(
8990
input_channels=input_channels,
@@ -93,6 +94,7 @@ def unet_tiny(
9394
encoder_basic_block=MultiLayerConvDropoutNormNonlin.get_block_constructor(1),
9495
decoder_basic_block=MultiLayerConvDropoutNormNonlin.get_block_constructor(1),
9596
deep_supervision=deep_supervision,
97+
use_skip_connections=use_skip_connections,
9698
)
9799

98100

0 commit comments

Comments
 (0)