Skip to content

Commit c685d24

Browse files
fix: ⚡️ Improve passing kwargs to PyTorch dataloader
1 parent 37be265 commit c685d24

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

src/cellmap_segmentation_challenge/utils/dataloader.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -149,27 +149,33 @@ def get_dataloader(
149149
)
150150

151151
if len(datasplit.validation_datasets) >= 0:
152+
_kwargs = {
153+
"classes": classes,
154+
"batch_size": batch_size,
155+
"is_train": random_validation,
156+
"device": device,
157+
}
158+
_kwargs.update(kwargs)
152159
validation_loader = CellMapDataLoader(
153160
datasplit.validation_blocks.to(device),
154-
classes=classes,
155-
batch_size=batch_size,
156-
is_train=random_validation,
157-
device=device,
158-
**kwargs,
161+
**_kwargs,
159162
)
160163
else:
161164
validation_loader = None
162165

163-
train_loader = CellMapDataLoader(
164-
datasplit.train_datasets_combined.to(device),
165-
classes=classes,
166-
batch_size=batch_size,
167-
sampler=lambda: datasplit.train_datasets_combined.get_subset_random_sampler(
166+
_kwargs = {
167+
"classes": classes,
168+
"batch_size": batch_size,
169+
"sampler": lambda: datasplit.train_datasets_combined.get_subset_random_sampler(
168170
iterations_per_epoch * batch_size, weighted=weighted_sampler
169171
),
170-
device=device,
171-
is_train=True,
172-
**kwargs,
172+
"device": device,
173+
"is_train": True,
174+
}
175+
_kwargs.update(kwargs)
176+
train_loader = CellMapDataLoader(
177+
datasplit.train_datasets_combined.to(device),
178+
**_kwargs,
173179
)
174180

175181
return train_loader, validation_loader # type: ignore

0 commit comments

Comments
 (0)