|
2 | 2 | from random import shuffle
|
3 | 3 | from typing import Dict, List, Optional, Tuple, Union
|
4 | 4 |
|
| 5 | +import keras |
5 | 6 | import numpy as np
|
6 |
| -import tensorflow as tf |
7 | 7 | from brainglobe_utils.cells.cells import Cell, group_cells_by_z
|
8 | 8 | from brainglobe_utils.general.numerical import is_even
|
| 9 | +from keras.utils import Sequence |
9 | 10 | from scipy.ndimage import zoom
|
10 | 11 | from skimage.io import imread
|
11 |
| -from tensorflow.keras.utils import Sequence |
12 | 12 |
|
13 | 13 | from cellfinder.core import types
|
14 | 14 | from cellfinder.core.classify.augment import AugmentationParameters, augment
|
@@ -56,7 +56,14 @@ def __init__(
|
56 | 56 | translate: Tuple[float, float, float] = (0.05, 0.05, 0.05),
|
57 | 57 | shuffle: bool = False,
|
58 | 58 | interpolation_order: int = 2,
|
| 59 | + *args, |
| 60 | + **kwargs, |
59 | 61 | ):
|
| 62 | + # pass any additional arguments not specified in signature to the |
| 63 | + # constructor of the superclass (e.g.: `use_multiprocessing` or |
| 64 | + # `workers`) |
| 65 | + super().__init__(*args, **kwargs) |
| 66 | + |
60 | 67 | self.points = points
|
61 | 68 | self.signal_array = signal_array
|
62 | 69 | self.background_array = background_array
|
@@ -220,7 +227,7 @@ def __getitem__(
|
220 | 227 |
|
221 | 228 | if self.train:
|
222 | 229 | batch_labels = [cell.type - 1 for cell in cell_batch]
|
223 |
| - batch_labels = tf.keras.utils.to_categorical( |
| 230 | + batch_labels = keras.utils.to_categorical( |
224 | 231 | batch_labels, num_classes=self.classes
|
225 | 232 | )
|
226 | 233 | return images, batch_labels
|
@@ -352,7 +359,14 @@ def __init__(
|
352 | 359 | translate: Tuple[float, float, float] = (0.2, 0.2, 0.2),
|
353 | 360 | train: bool = False, # also return labels
|
354 | 361 | interpolation_order: int = 2,
|
| 362 | + *args, |
| 363 | + **kwargs, |
355 | 364 | ):
|
| 365 | + # pass any additional arguments not specified in signature to the |
| 366 | + # constructor of the superclass (e.g.: `use_multiprocessing` or |
| 367 | + # `workers`) |
| 368 | + super().__init__(*args, **kwargs) |
| 369 | + |
356 | 370 | self.im_shape = shape
|
357 | 371 | self.batch_size = batch_size
|
358 | 372 | self.labels = labels
|
@@ -414,7 +428,7 @@ def __getitem__(
|
414 | 428 |
|
415 | 429 | if self.train and self.labels is not None:
|
416 | 430 | batch_labels = [self.labels[k] for k in indexes]
|
417 |
| - batch_labels = tf.keras.utils.to_categorical( |
| 431 | + batch_labels = keras.utils.to_categorical( |
418 | 432 | batch_labels, num_classes=self.classes
|
419 | 433 | )
|
420 | 434 | return images, batch_labels
|
|
0 commit comments