File tree 1 file changed +19
-13
lines changed
src/cellmap_segmentation_challenge/utils
1 file changed +19
-13
lines changed Original file line number Diff line number Diff line change @@ -149,27 +149,33 @@ def get_dataloader(
149
149
)
150
150
151
151
if len (datasplit .validation_datasets ) >= 0 :
152
+ _kwargs = {
153
+ "classes" : classes ,
154
+ "batch_size" : batch_size ,
155
+ "is_train" : random_validation ,
156
+ "device" : device ,
157
+ }
158
+ _kwargs .update (kwargs )
152
159
validation_loader = CellMapDataLoader (
153
160
datasplit .validation_blocks .to (device ),
154
- classes = classes ,
155
- batch_size = batch_size ,
156
- is_train = random_validation ,
157
- device = device ,
158
- ** kwargs ,
161
+ ** _kwargs ,
159
162
)
160
163
else :
161
164
validation_loader = None
162
165
163
- train_loader = CellMapDataLoader (
164
- datasplit .train_datasets_combined .to (device ),
165
- classes = classes ,
166
- batch_size = batch_size ,
167
- sampler = lambda : datasplit .train_datasets_combined .get_subset_random_sampler (
166
+ _kwargs = {
167
+ "classes" : classes ,
168
+ "batch_size" : batch_size ,
169
+ "sampler" : lambda : datasplit .train_datasets_combined .get_subset_random_sampler (
168
170
iterations_per_epoch * batch_size , weighted = weighted_sampler
169
171
),
170
- device = device ,
171
- is_train = True ,
172
- ** kwargs ,
172
+ "device" : device ,
173
+ "is_train" : True ,
174
+ }
175
+ _kwargs .update (kwargs )
176
+ train_loader = CellMapDataLoader (
177
+ datasplit .train_datasets_combined .to (device ),
178
+ ** _kwargs ,
173
179
)
174
180
175
181
return train_loader , validation_loader # type: ignore
You can’t perform that action at this time.
0 commit comments