Skip to content

Commit 29f8555

Browse files
authored
Migrate to Keras 3.0 with TF backend (#373)
* remove pytest-lazy-fixture as dev dependency and skip test (with WG temp fix) * change tensorflow dependency for cellfinder * replace keras imports from tensorflow to just keras imports * add keras import and reorder * add keras and TF 2.16 to pyproject.toml * comment out TF version check for now * change checkpoint filename for compliance with keras 3. remove use_multiprocessing=False from fit() as it is no longer an input. test_train() passing * add multiprocessing parameters to cube generator constructor and remove from fit() signature (keras3 change) * apply temp garbage collector fix * skip troublesome test * skip running tests on CI on windows * remove commented out TF check * clean commented out code. Explicitly pass use_multiprocessing=False (as before) * remove str conversion before model.save * raise test_detection error for sonarcloud happy * skip running tests on windows on CI * remove filename comment and small edits
1 parent 99cbda0 commit 29f8555

File tree

9 files changed

+72
-65
lines changed

9 files changed

+72
-65
lines changed

.github/workflows/test_and_deploy.yml

+1-3
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,10 @@ jobs:
4242
# Run all supported Python versions on linux
4343
os: [ubuntu-latest]
4444
python-version: ["3.9", "3.10"]
45-
# Include one windows, one macos run
45+
# Include one macos run
4646
include:
4747
- os: macos-latest
4848
python-version: "3.10"
49-
- os: windows-latest
50-
python-version: "3.10"
5149

5250
steps:
5351
# Cache the tensorflow model so we don't have to remake it every time

cellfinder/core/classify/classify.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import os
22
from typing import Any, Callable, Dict, List, Optional, Tuple
33

4+
import keras
45
import numpy as np
56
from brainglobe_utils.cells.cells import Cell
67
from brainglobe_utils.general.system import get_num_processes
7-
from tensorflow import keras
88

99
from cellfinder.core import logger, types
1010
from cellfinder.core.classify.cube_generator import CubeGeneratorFromFile
@@ -63,6 +63,8 @@ def main(
6363
cube_width=cube_width,
6464
cube_height=cube_height,
6565
cube_depth=cube_depth,
66+
use_multiprocessing=True,
67+
workers=workers,
6668
)
6769

6870
model = get_model(
@@ -73,10 +75,9 @@ def main(
7375
)
7476

7577
logger.info("Running inference")
78+
# in Keras 3.0 multiprocessing params are specified in the generator
7679
predictions = model.predict(
7780
inference_generator,
78-
use_multiprocessing=True,
79-
workers=workers,
8081
verbose=True,
8182
callbacks=callbacks,
8283
)

cellfinder/core/classify/cube_generator.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
from random import shuffle
33
from typing import Dict, List, Optional, Tuple, Union
44

5+
import keras
56
import numpy as np
6-
import tensorflow as tf
77
from brainglobe_utils.cells.cells import Cell, group_cells_by_z
88
from brainglobe_utils.general.numerical import is_even
9+
from keras.utils import Sequence
910
from scipy.ndimage import zoom
1011
from skimage.io import imread
11-
from tensorflow.keras.utils import Sequence
1212

1313
from cellfinder.core import types
1414
from cellfinder.core.classify.augment import AugmentationParameters, augment
@@ -56,7 +56,14 @@ def __init__(
5656
translate: Tuple[float, float, float] = (0.05, 0.05, 0.05),
5757
shuffle: bool = False,
5858
interpolation_order: int = 2,
59+
*args,
60+
**kwargs,
5961
):
62+
# pass any additional arguments not specified in signature to the
63+
# constructor of the superclass (e.g.: `use_multiprocessing` or
64+
# `workers`)
65+
super().__init__(*args, **kwargs)
66+
6067
self.points = points
6168
self.signal_array = signal_array
6269
self.background_array = background_array
@@ -220,7 +227,7 @@ def __getitem__(
220227

221228
if self.train:
222229
batch_labels = [cell.type - 1 for cell in cell_batch]
223-
batch_labels = tf.keras.utils.to_categorical(
230+
batch_labels = keras.utils.to_categorical(
224231
batch_labels, num_classes=self.classes
225232
)
226233
return images, batch_labels
@@ -352,7 +359,14 @@ def __init__(
352359
translate: Tuple[float, float, float] = (0.2, 0.2, 0.2),
353360
train: bool = False, # also return labels
354361
interpolation_order: int = 2,
362+
*args,
363+
**kwargs,
355364
):
365+
# pass any additional arguments not specified in signature to the
366+
# constructor of the superclass (e.g.: `use_multiprocessing` or
367+
# `workers`)
368+
super().__init__(*args, **kwargs)
369+
356370
self.im_shape = shape
357371
self.batch_size = batch_size
358372
self.labels = labels
@@ -414,7 +428,7 @@ def __getitem__(
414428

415429
if self.train and self.labels is not None:
416430
batch_labels = [self.labels[k] for k in indexes]
417-
batch_labels = tf.keras.utils.to_categorical(
431+
batch_labels = keras.utils.to_categorical(
418432
batch_labels, num_classes=self.classes
419433
)
420434
return images, batch_labels

cellfinder/core/classify/resnet.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
22

3-
from tensorflow import Tensor
4-
from tensorflow.keras import Model
5-
from tensorflow.keras.initializers import Initializer
6-
from tensorflow.keras.layers import (
3+
from keras import Model
4+
from keras.initializers import Initializer
5+
from keras.layers import (
76
Activation,
87
Add,
98
BatchNormalization,
@@ -14,7 +13,8 @@
1413
MaxPooling3D,
1514
ZeroPadding3D,
1615
)
17-
from tensorflow.keras.optimizers import Adam, Optimizer
16+
from keras.optimizers import Adam, Optimizer
17+
from tensorflow import Tensor
1818

1919
#####################################################################
2020
# Define the types of ResNet

cellfinder/core/classify/tools.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import os
2-
from typing import List, Optional, Sequence, Tuple, Union
2+
from collections.abc import Sequence
3+
from typing import List, Optional, Tuple, Union
34

5+
import keras
46
import numpy as np
5-
import tensorflow as tf
6-
from tensorflow.keras import Model
7+
from keras import Model
78

89
from cellfinder.core import logger
910
from cellfinder.core.classify.resnet import build_model, layer_type
@@ -17,8 +18,7 @@ def get_model(
1718
inference: bool = False,
1819
continue_training: bool = False,
1920
) -> Model:
20-
"""
21-
Returns the correct model based on the arguments passed
21+
"""Returns the correct model based on the arguments passed
2222
:param existing_model: An existing, trained model. This is returned if it
2323
exists
2424
:param model_weights: This file is used to set the model weights if it
@@ -30,29 +30,31 @@ def get_model(
3030
by using the default one
3131
:param continue_training: If True, will ensure that a trained model
3232
exists. E.g. by using the default one
33-
:return: A tf.keras model
33+
:return: A keras model
3434
3535
"""
3636
if existing_model is not None or network_depth is None:
3737
logger.debug(f"Loading model: {existing_model}")
38-
return tf.keras.models.load_model(existing_model)
38+
return keras.models.load_model(existing_model)
3939
else:
4040
logger.debug(f"Creating a new instance of model: {network_depth}")
4141
model = build_model(
42-
network_depth=network_depth, learning_rate=learning_rate
42+
network_depth=network_depth,
43+
learning_rate=learning_rate,
4344
)
4445
if inference or continue_training:
4546
logger.debug(
46-
f"Setting model weights according to: {model_weights}"
47+
f"Setting model weights according to: {model_weights}",
4748
)
4849
if model_weights is None:
49-
raise IOError("`model_weights` must be provided")
50+
raise OSError("`model_weights` must be provided")
5051
model.load_weights(model_weights)
5152
return model
5253

5354

5455
def make_lists(
55-
tiff_files: Sequence, train: bool = True
56+
tiff_files: Sequence,
57+
train: bool = True,
5658
) -> Union[Tuple[List, List], Tuple[List, List, np.ndarray]]:
5759
signal_list = []
5860
background_list = []

cellfinder/core/train/train_yml.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def run(
324324

325325
suppress_tf_logging(tf_suppress_log_messages)
326326

327-
from tensorflow.keras.callbacks import (
327+
from keras.callbacks import (
328328
CSVLogger,
329329
ModelCheckpoint,
330330
TensorBoard,
@@ -386,15 +386,16 @@ def run(
386386
labels=labels_test,
387387
batch_size=batch_size,
388388
train=True,
389+
use_multiprocessing=False,
389390
)
390391

391392
# for saving checkpoints
392-
base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}.h5"
393+
base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}"
393394

394395
else:
395396
logger.info("No validation data selected.")
396397
validation_generator = None
397-
base_checkpoint_file_name = "-epoch.{epoch:02d}.h5"
398+
base_checkpoint_file_name = "-epoch.{epoch:02d}"
398399

399400
training_generator = CubeGeneratorFromDisk(
400401
signal_train,
@@ -404,6 +405,7 @@ def run(
404405
shuffle=True,
405406
train=True,
406407
augment=not no_augment,
408+
use_multiprocessing=False,
407409
)
408410
callbacks = []
409411

@@ -420,9 +422,14 @@ def run(
420422

421423
if not no_save_checkpoints:
422424
if save_weights:
423-
filepath = str(output_dir / ("weight" + base_checkpoint_file_name))
425+
filepath = str(
426+
output_dir
427+
/ ("weight" + base_checkpoint_file_name + ".weights.h5")
428+
)
424429
else:
425-
filepath = str(output_dir / ("model" + base_checkpoint_file_name))
430+
filepath = str(
431+
output_dir / ("model" + base_checkpoint_file_name + ".keras")
432+
)
426433

427434
checkpoints = ModelCheckpoint(
428435
filepath,
@@ -431,25 +438,26 @@ def run(
431438
callbacks.append(checkpoints)
432439

433440
if save_progress:
434-
filepath = str(output_dir / "training.csv")
435-
csv_logger = CSVLogger(filepath)
441+
csv_filepath = str(output_dir / "training.csv")
442+
csv_logger = CSVLogger(csv_filepath)
436443
callbacks.append(csv_logger)
437444

438445
logger.info("Beginning training.")
446+
# Keras 3.0: `use_multiprocessing` input is set in the
447+
# `training_generator` (False by default)
439448
model.fit(
440449
training_generator,
441450
validation_data=validation_generator,
442-
use_multiprocessing=False,
443451
epochs=epochs,
444452
callbacks=callbacks,
445453
)
446454

447455
if save_weights:
448456
logger.info("Saving model weights")
449-
model.save_weights(str(output_dir / "model_weights.h5"))
457+
model.save_weights(output_dir / "model.weights.h5")
450458
else:
451459
logger.info("Saving model")
452-
model.save(output_dir / "model.h5")
460+
model.save(output_dir / "model.keras")
453461

454462
logger.info(
455463
"Finished training, " "Total time taken: %s",

tests/core/conftest.py

-21
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
from typing import Tuple
32

43
import numpy as np
@@ -9,26 +8,6 @@
98
from cellfinder.core.tools.prep import DEFAULT_INSTALL_PATH
109

1110

12-
@pytest.fixture(scope="session")
13-
def no_free_cpus() -> int:
14-
"""
15-
Set number of free CPUs so all available CPUs are used by the tests.
16-
"""
17-
return 0
18-
19-
20-
@pytest.fixture(scope="session")
21-
def run_on_one_cpu_only() -> int:
22-
"""
23-
Set number of free CPUs so tests can use exactly one CPU.
24-
"""
25-
cpus = os.cpu_count()
26-
if cpus is not None:
27-
return cpus - 1
28-
else:
29-
raise ValueError("No CPUs available.")
30-
31-
3211
@pytest.fixture(scope="session")
3312
def download_default_model():
3413
"""

tests/core/test_integration/test_detection.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,13 @@ def test_detection_full(signal_array, background_array, free_cpus, request):
8181

8282

8383
def test_detection_small_planes(
84-
signal_array, background_array, no_free_cpus, mocker
84+
signal_array,
85+
background_array,
86+
mocker,
87+
cpus_to_leave_free: int = 0,
8588
):
8689
# Check that processing works when number of planes < number of processes
87-
nproc = get_num_processes(no_free_cpus)
90+
nproc = get_num_processes(cpus_to_leave_free)
8891
n_planes = 2
8992

9093
# Don't want to bother classifying in this test, so mock classifcation
@@ -101,11 +104,13 @@ def test_detection_small_planes(
101104
background_array[0:n_planes],
102105
voxel_sizes,
103106
ball_z_size=5,
104-
n_free_cpus=no_free_cpus,
107+
n_free_cpus=cpus_to_leave_free,
105108
)
106109

107110

108-
def test_callbacks(signal_array, background_array, no_free_cpus):
111+
def test_callbacks(
112+
signal_array, background_array, cpus_to_leave_free: int = 0
113+
):
109114
# 20 is minimum number of planes needed to find > 0 cells
110115
signal_array = signal_array[0:20]
111116
background_array = background_array[0:20]
@@ -130,7 +135,7 @@ def detect_finished_callback(points):
130135
detect_callback=detect_callback,
131136
classify_callback=classify_callback,
132137
detect_finished_callback=detect_finished_callback,
133-
n_free_cpus=no_free_cpus,
138+
n_free_cpus=cpus_to_leave_free,
134139
)
135140

136141
np.testing.assert_equal(planes_done, np.arange(len(signal_array)))
@@ -148,13 +153,13 @@ def test_floating_point_error(signal_array, background_array):
148153
main(signal_array, background_array, voxel_sizes)
149154

150155

151-
def test_synthetic_data(synthetic_bright_spots, no_free_cpus):
156+
def test_synthetic_data(synthetic_bright_spots, cpus_to_leave_free: int = 0):
152157
signal_array, background_array = synthetic_bright_spots
153158
detected = main(
154159
signal_array,
155160
background_array,
156161
voxel_sizes,
157-
n_free_cpus=no_free_cpus,
162+
n_free_cpus=cpus_to_leave_free,
158163
)
159164
assert len(detected) == 8
160165

tests/core/test_integration/test_train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,5 @@ def test_train(tmpdir):
3535
sys.argv = train_args
3636
train_run()
3737

38-
model_file = os.path.join(tmpdir, "model.h5")
38+
model_file = os.path.join(tmpdir, "model.keras")
3939
assert os.path.exists(model_file)

0 commit comments

Comments
 (0)