Commit cbdecaf

IgorTatarnikov, sfmig, K-Meech, matham, and pre-commit-ci[bot] authored
Updating to Keras 3.0 and migrating to PyTorch (#418)
* remove pytest-lazy-fixture as dev dependency and skip test (with WG temp fix)
* Test Keras is present (#374)
* check if Keras present
* change TF to Keras in CI
* remove comment
* change dependencies in pyproject.toml for Keras 3.0
* Migrate to Keras 3.0 with TF backend (#373)
* remove pytest-lazy-fixture as dev dependency and skip test (with WG temp fix)
* change tensorflow dependency for cellfinder
* replace keras imports from tensorflow to just keras imports
* add keras import and reorder
* add keras and TF 2.16 to pyproject.toml
* comment out TF version check for now
* change checkpoint filename for compliance with keras 3. remove use_multiprocessing=False from fit() as it is no longer an input. test_train() passing
* add multiprocessing parameters to cube generator constructor and remove from fit() signature (keras3 change)
* apply temp garbage collector fix
* skip troublesome test
* skip running tests on CI on windows
* remove commented out TF check
* clean commented out code. Explicitly pass use_multiprocessing=False (as before)
* remove str conversion before model.save
* raise test_detection error for sonarcloud happy
* skip running tests on windows on CI
* remove filename comment and small edits
* Replace TF references in comments and warning messages (#378)
* change some old references to TF for the import check
* change TF cached model to Keras
* Cellfinder with Keras 3.0 and jax backend (#379)
* replace tensorflow Tensor with keras tensor
* add case for TF prep in prep_model_weights
* add different backends to pyproject.toml
* add backend configuration to cellfinder init file. tests passing with jax locally
* define extra dependencies for cellfinder with different backends. run tox with TF backend
* run tox using TF and JAX backend
* install TF in brainmapper environment before running tests in CI
* add backends check to cellfinder init file
* clean up comments
* fix tf-nightly import check
* specify TF backend in include guard check
* clarify comment
* remove 'backend' from dependencies specifications
* Apply suggestions from code review
  Co-authored-by: Igor Tatarnikov <[email protected]>
---------
Co-authored-by: Igor Tatarnikov <[email protected]>
* Run cellfinder with JAX in Windows tests in CI (#382)
* use jax backend in brainmapper tests in CI
* skip TF backend on windows
* fix pip install cellfinder for brainmapper CI tests
* add keras env variable for brainmapper CLI tests
* fix prep_model_weights
* It/keras3 pytorch (#396)
* replace tensorflow Tensor with keras tensor
* add case for TF prep in prep_model_weights
* add different backends to pyproject.toml
* add backend configuration to cellfinder init file. tests passing with jax locally
* define extra dependencies for cellfinder with different backends. run tox with TF backend
* run tox using TF and JAX backend
* install TF in brainmapper environment before running tests in CI
* add backends check to cellfinder init file
* clean up comments
* fix tf-nightly import check
* specify TF backend in include guard check
* clarify comment
* remove 'backend' from dependencies specifications
* Apply suggestions from code review
  Co-authored-by: Igor Tatarnikov <[email protected]>
* PyTorch runs utilizing multiple cores
* PyTorch fix with default models
* Tests run on every push for now
* Run test on torch backend only
* Fixed guard test to set torch as KERAS_BACKEND
* KERAS_BACKEND env variable set directly in test_include_guard.yaml
* Run test on python 3.11
* Remove tf-nightly from __init__ version check
* Added 3.11 to legacy tox config
* Changed legacy tox config for real this time
* Don't set the wrong max_processing value
* Torch is now set as the default backend
* Tests only run with torch, updated comments
* Unpinned torch version
* Add codecov token (#403)
* add codecov token
* generate xml coverage report
* add timeout to testing jobs
* Allow turning off classification or detection in GUI (#402)
* Allow turning off classification or detection in GUI.
* Fix test.
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Refactor to fix code analysis errors.
* Ensure array is always 2d.
* Apply suggestions from code review
  Co-authored-by: Igor Tatarnikov <[email protected]>
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Igor Tatarnikov <[email protected]>
* Support single z-stack tif file for input (#397)
* Support single z-stack tif file for input.
* Fix commit hook.
* Apply review suggestions.
* Remove modular asv benchmarks (#406)
* remove modular asv benchmarks
* recover old structure
* remove asv-specific lines from gitignore and manifest
* prune benchmarks
* Adapt CI so it covers both new and old Macs, and installs required additional dependencies on M1 (#408)
* naive attempt at adapting to silicon mac CI
* run include guard test on Silicon CI
* double-check hdf5 is needed
* Optimize cell detection (#398) (#407)
* Replace coord map values with numba list/tuple for optim.
* Switch to fortran layout for faster update of last dim.
* Cache kernel.
* jit ball filter.
* Put z as first axis to speed z rolling (row-major memory).
* Unroll recursion (no perf impact either way).
* Parallelize cell cluster splitting.
* Parallelize walking for full images.
* Cleanup docs and pep8 etc.
* Add pre-commit fixes.
* Fix parallel always being selected and numba function 1st class warning.
* Run hook.
* Older python needs Union instead of |.
* Accept review suggestion.
* Address review changes.
* num_threads must be an int.
---------
Co-authored-by: Matt Einhorn <[email protected]>
* [pre-commit.ci] pre-commit autoupdate (#412)
  updates:
  - [github.com/pre-commit/pre-commit-hooks: v4.5.0 → v4.6.0](pre-commit/pre-commit-hooks@v4.5.0...v4.6.0)
  - [github.com/astral-sh/ruff-pre-commit: v0.3.5 → v0.4.3](astral-sh/ruff-pre-commit@v0.3.5...v0.4.3)
  - [github.com/psf/black: 24.3.0 → 24.4.2](psf/black@24.3.0...24.4.2)
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Apply suggestions from code review
  Co-authored-by: sfmig <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Simplify model download (#414)
* Simplify model download
* Update model cache
* Remove jax and tf tests
* Standardise the data types for inputs to all be float32
* Force torch to use CPU on arm based macOS during tests
* Added PYTORCH_MPS_HIGH_WATERMARK_RATIO env variable
* Set env variables in test setup
* Try to set the default device to cpu in the test itself
* Add device call to Conv3D to force cpu
* Revert changes, request one cpu left free
* Revert the number of cores, don't use arm based mac runner
* Merged main, removed torch flags on cellfinder install for guards and brainmapper
* Lowercase Torch
* Change cache directory
---------
Co-authored-by: sfmig <[email protected]>
Co-authored-by: Kimberly Meechan <[email protected]>
Co-authored-by: Matt Einhorn <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Alessandro Felder <[email protected]>
Co-authored-by: Adam Tyson <[email protected]>
* Set pooling padding to valid by default on all MaxPooling3D layers
* Removed tf error suppression and other tf related functions
* Force torch to use cpu device when CELLFINDER_TEST_DEVICE env variable set to cpu
* Added env variable to test step
* Use the GITHUB_ACTIONS environmental variable instead
* Added docstring for fixture setting device to cpu on arm based mac
* Revert changes to no_free_cpus being fixture, and default param
* Fixed typo in test_and_deploy.yml
* Set multiprocessing to false for the data generators
* Update all cache steps to match
* Remove reference to TF
* Make sure tests can run locally when GITHUB_ACTIONS env variable is missing
* Removed warning when backend is not configured
* Set the label tensor to be float32 to ensure compatibility with mps
* Always set KERAS_BACKEND to torch on init
* Remove code in __init__ checking for if backend is installed
---------
Co-authored-by: sfmig <[email protected]>
Co-authored-by: Kimberly Meechan <[email protected]>
Co-authored-by: Matt Einhorn <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Alessandro Felder <[email protected]>
Co-authored-by: Adam Tyson <[email protected]>
1 parent de834d2 commit cbdecaf

15 files changed (+157 -177 lines changed)

.github/workflows/test_and_deploy.yml

+21 -11
@@ -37,12 +37,15 @@ jobs:
     name: Run package tests
     timeout-minutes: 60
     runs-on: ${{ matrix.os }}
+    env:
+      KERAS_BACKEND: torch
+      CELLFINDER_TEST_DEVICE: cpu
     strategy:
       matrix:
         # Run all supported Python versions on linux
         os: [ubuntu-latest]
-        python-version: ["3.9", "3.10"]
-        # Include one windows, one macos run each for M1 (latest) and Intel (13)
+        python-version: ["3.9", "3.10", "3.11"]
+        # Include one windows and two macOS (intel based and arm based) runs
         include:
           - os: macos-13
             python-version: "3.10"
@@ -80,11 +83,13 @@ jobs:
       NUMBA_DISABLE_JIT: "1"
 
     steps:
-      - name: Cache tensorflow model
+      - name: Cache brainglobe directory
         uses: actions/cache@v3
         with:
-          path: "~/.cellfinder"
-          key: models-${{ hashFiles('~/.brainglobe/**') }}
+          path: | # ensure we don't cache any interrupted atlas download and extraction, if e.g. we cancel the workflow manually
+            ~/.brainglobe
+            !~/.brainglobe/atlas.tar.gz
+          key: brainglobe
       # Setup pyqt libraries
       - name: Setup qtpy libraries
         uses: tlambert03/setup-qt-libs@v1
@@ -104,13 +109,17 @@ jobs:
     name: Run brainmapper tests to check for breakages
     timeout-minutes: 60
     runs-on: ubuntu-latest
+    env:
+      KERAS_BACKEND: torch
+      CELLFINDER_TEST_DEVICE: cpu
     steps:
-      - name: Cache tensorflow model
+      - name: Cache brainglobe directory
         uses: actions/cache@v3
         with:
-          path: "~/.cellfinder"
-          key: models-${{ hashFiles('~/.brainglobe/**') }}
-
+          path: | # ensure we don't cache any interrupted atlas download and extraction, if e.g. we cancel the workflow manually
+            ~/.brainglobe
+            !~/.brainglobe/atlas.tar.gz
+          key: brainglobe
       - name: Checkout brainglobe-workflows
         uses: actions/checkout@v3
         with:
@@ -124,8 +133,9 @@ jobs:
       - name: Install test dependencies
         run: |
           python -m pip install --upgrade pip wheel
-          # Install latest SHA on this brainglobe-workflows branch
-          python -m pip install git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA
+          # Install cellfinder from the latest SHA on this branch
+          python -m pip install "cellfinder @ git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA"
+
           # Install checked out copy of brainglobe-workflows
           python -m pip install .[dev]
 
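
The workflow above only exports `KERAS_BACKEND` and `CELLFINDER_TEST_DEVICE`; how the test suite consumes the second variable is not part of this file's diff. The commit message says torch is forced onto the CPU when the variable is set to `cpu`, so a minimal sketch of that idea could look like the following. The helper name `resolve_test_device` and its behaviour are assumptions made for illustration, not code from the repository.

```python
import os
from typing import Mapping, Optional


def resolve_test_device(env: Optional[Mapping[str, str]] = None) -> Optional[str]:
    """Return the device requested via CELLFINDER_TEST_DEVICE, or None.

    Hypothetical helper: the name and normalisation rules are illustrative only.
    """
    # Fall back to the real process environment when no mapping is passed in.
    env = os.environ if env is None else env
    device = env.get("CELLFINDER_TEST_DEVICE", "").strip().lower()
    return device or None


# A test fixture could then act on the result, for example (requires torch >= 2.0):
#     if resolve_test_device() == "cpu":
#         torch.set_default_device("cpu")
```

Passing the environment as an argument keeps the helper easy to test without mutating `os.environ`.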

.github/workflows/test_include_guard.yaml

+9 -12
@@ -1,5 +1,5 @@
-name: Test Tensorflow include guards
-# These tests check that the include guards checking for tensorflow's availability
+name: Test Keras include guards
+# These tests check that the include guards checking for Keras availability
 # behave as expected on ubuntu and macOS.
 
 on:
@@ -9,7 +9,7 @@ on:
       - main
 
 jobs:
-  tensorflow_guards:
+  keras_guards:
     name: Test include guards
     strategy:
       matrix:
@@ -24,24 +24,21 @@ jobs:
         with:
           python-version: '3.10'
 
-      - name: Install via pip
-        run: python -m pip install -e .
+      - name: Install cellfinder via pip
+        run: python -m pip install -e "."
 
       - name: Test (working) import
         uses: jannekem/run-python-script-action@v1
+        env:
+          KERAS_BACKEND: torch
         with:
           fail-on-error: true
           script: |
             import cellfinder.core
             import cellfinder.napari
 
-      - name: Uninstall tensorflow-macos on Mac M1
-        if: matrix.os == 'macos-latest'
-        run: python -m pip uninstall -y tensorflow-macos
-
-      - name: Uninstall tensorflow on Ubuntu
-        if: matrix.os == 'ubuntu-latest'
-        run: python -m pip uninstall -y tensorflow
+      - name: Uninstall keras
+        run: python -m pip uninstall -y keras
 
       - name: Test (broken) import
         id: broken_import

cellfinder/__init__.py

+17 -11
@@ -1,25 +1,31 @@
+import os
 from importlib.metadata import PackageNotFoundError, version
 from pathlib import Path
 
+# Check cellfinder is installed
 try:
     __version__ = version("cellfinder")
 except PackageNotFoundError as e:
     raise PackageNotFoundError("cellfinder package not installed") from e
 
-# If tensorflow is not present, tools cannot be used.
+# If Keras is not present, tools cannot be used.
 # Throw an error in this case to prevent invocation of functions.
 try:
-    TF_VERSION = version("tensorflow")
+    KERAS_VERSION = version("keras")
 except PackageNotFoundError as e:
-    try:
-        TF_VERSION = version("tensorflow-macos")
-    except PackageNotFoundError as e:
-        raise PackageNotFoundError(
-            f"cellfinder tools cannot be invoked without tensorflow. "
-            f"Please install tensorflow into your environment to use cellfinder tools. "
-            f"For more information, please see "
-            f"https://github.com/brainglobe/brainglobe-meta#readme."
-        ) from e
+    raise PackageNotFoundError(
+        f"cellfinder tools cannot be invoked without Keras. "
+        f"Please install Keras with a backend into your environment "
+        f"to use cellfinder tools. "
+        f"For more information on Keras backends, please see "
+        f"https://keras.io/getting_started/#installing-keras-3."
+        f"For more information on brainglobe, please see "
+        f"https://github.com/brainglobe/brainglobe-meta#readme."
+    ) from e
+
+
+# Set the Keras backend to torch
+os.environ["KERAS_BACKEND"] = "torch"
 
 __author__ = "Adam Tyson, Christian Niedworok, Charly Rousseau"
 __license__ = "BSD-3-Clause"
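
The pattern in this file is an include guard followed by a backend selection: check that a distribution is installed using package metadata, then set `KERAS_BACKEND` before anything imports `keras`, since Keras reads that variable when it is first imported. A minimal, Keras-free sketch of the same two steps is below. The function names `require_distribution` and `select_keras_backend` are hypothetical and are not part of cellfinder.

```python
import os
from importlib.metadata import PackageNotFoundError, version


def require_distribution(name: str) -> str:
    """Return the installed version of `name`, or raise with a clearer hint.

    Hypothetical helper mirroring the guard above; not cellfinder's API.
    """
    try:
        return version(name)
    except PackageNotFoundError as e:
        raise PackageNotFoundError(
            f"{name} is required but is not installed in this environment."
        ) from e


def select_keras_backend(backend: str = "torch") -> str:
    """Set KERAS_BACKEND; this must run before the first `import keras`."""
    os.environ["KERAS_BACKEND"] = backend
    return os.environ["KERAS_BACKEND"]
```

Because the guard only inspects metadata, it never imports the heavy package itself, which keeps the failure fast and the error message under the caller's control.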

cellfinder/core/classify/classify.py

+5 -6
@@ -1,10 +1,10 @@
 import os
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
+import keras
 import numpy as np
 from brainglobe_utils.cells.cells import Cell
 from brainglobe_utils.general.system import get_num_processes
-from tensorflow import keras
 
 from cellfinder.core import logger, types
 from cellfinder.core.classify.cube_generator import CubeGeneratorFromFile
@@ -48,9 +48,7 @@ def main(
         callbacks = None
 
     # Too many workers doesn't increase speed, and uses huge amounts of RAM
-    workers = get_num_processes(
-        min_free_cpu_cores=n_free_cpus, n_max_processes=max_workers
-    )
+    workers = get_num_processes(min_free_cpu_cores=n_free_cpus)
 
     logger.debug("Initialising cube generator")
     inference_generator = CubeGeneratorFromFile(
@@ -63,6 +61,8 @@ def main(
         cube_width=cube_width,
         cube_height=cube_height,
         cube_depth=cube_depth,
+        use_multiprocessing=False,
+        workers=workers,
     )
 
     model = get_model(
@@ -73,10 +73,9 @@ def main(
     )
 
     logger.info("Running inference")
+    # in Keras 3.0 multiprocessing params are specified in the generator
     predictions = model.predict(
         inference_generator,
-        use_multiprocessing=True,
-        workers=workers,
         verbose=True,
         callbacks=callbacks,
     )
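
The worker count passed to the generator comes from `get_num_processes(min_free_cpu_cores=n_free_cpus)` in `brainglobe_utils`, whose implementation is not shown in this diff. As a rough mental model only, the idea of "leave some cores free, but always keep at least one worker" can be sketched with the standard library. The function `count_workers` below is an assumption for illustration and is not the library's implementation, which may also account for memory limits and other constraints.

```python
import os
from typing import Optional


def count_workers(min_free_cpu_cores: int, total_cpus: Optional[int] = None) -> int:
    """Approximate a worker count that leaves some CPU cores unused.

    Illustrative stand-in; NOT brainglobe_utils.get_num_processes.
    """
    if total_cpus is None:
        # os.cpu_count() can return None on some platforms, so default to 1.
        total_cpus = os.cpu_count() or 1
    # Never return fewer than one worker, even if more cores are reserved
    # than the machine actually has.
    return max(1, total_cpus - min_free_cpu_cores)
```

For example, on an 8-core machine `count_workers(2, total_cpus=8)` gives 6 workers, and asking to keep 16 cores free on a 4-core machine still yields 1.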

cellfinder/core/classify/cube_generator.py

+25 -9
@@ -2,13 +2,13 @@
 from random import shuffle
 from typing import Dict, List, Optional, Tuple, Union
 
+import keras
 import numpy as np
-import tensorflow as tf
 from brainglobe_utils.cells.cells import Cell, group_cells_by_z
 from brainglobe_utils.general.numerical import is_even
+from keras.utils import Sequence
 from scipy.ndimage import zoom
 from skimage.io import imread
-from tensorflow.keras.utils import Sequence
 
 from cellfinder.core import types
 from cellfinder.core.classify.augment import AugmentationParameters, augment
@@ -56,7 +56,14 @@ def __init__(
         translate: Tuple[float, float, float] = (0.05, 0.05, 0.05),
         shuffle: bool = False,
         interpolation_order: int = 2,
+        *args,
+        **kwargs,
     ):
+        # pass any additional arguments not specified in signature to the
+        # constructor of the superclass (e.g.: `use_multiprocessing` or
+        # `workers`)
+        super().__init__(*args, **kwargs)
+
         self.points = points
         self.signal_array = signal_array
         self.background_array = background_array
@@ -218,10 +225,10 @@ def __getitem__(self, index: int) -> Union[
 
         if self.train:
             batch_labels = [cell.type - 1 for cell in cell_batch]
-            batch_labels = tf.keras.utils.to_categorical(
+            batch_labels = keras.utils.to_categorical(
                 batch_labels, num_classes=self.classes
             )
-            return images, batch_labels
+            return images, batch_labels.astype(np.float32)
         elif self.extract:
             batch_info = self.__get_batch_dict(cell_batch)
             return images, batch_info
@@ -252,7 +259,8 @@ def __generate_cubes(
                 (number_images,)
                 + (self.cube_height, self.cube_width, self.cube_depth)
                 + (self.channels,)
-            )
+            ),
+            dtype=np.float32,
         )
 
         for idx, cell in enumerate(cell_batch):
@@ -350,7 +358,14 @@ def __init__(
         translate: Tuple[float, float, float] = (0.2, 0.2, 0.2),
         train: bool = False,  # also return labels
         interpolation_order: int = 2,
+        *args,
+        **kwargs,
     ):
+        # pass any additional arguments not specified in signature to the
+        # constructor of the superclass (e.g.: `use_multiprocessing` or
+        # `workers`)
+        super().__init__(*args, **kwargs)
+
         self.im_shape = shape
         self.batch_size = batch_size
         self.labels = labels
@@ -410,10 +425,10 @@ def __getitem__(self, index: int) -> Union[
 
         if self.train and self.labels is not None:
             batch_labels = [self.labels[k] for k in indexes]
-            batch_labels = tf.keras.utils.to_categorical(
+            batch_labels = keras.utils.to_categorical(
                 batch_labels, num_classes=self.classes
             )
-            return images, batch_labels
+            return images, batch_labels.astype(np.float32)
         else:
             return images
 
@@ -424,7 +439,8 @@ def __generate_cubes(
     ) -> np.ndarray:
         number_images = len(list_signal_tmp)
         images = np.empty(
-            ((number_images,) + self.im_shape + (self.channels,))
+            ((number_images,) + self.im_shape + (self.channels,)),
+            dtype=np.float32,
         )
 
         for idx, signal_im in enumerate(list_signal_tmp):
@@ -433,7 +449,7 @@ def __generate_cubes(
                 images, idx, signal_im, background_im
             )
 
-        return images.astype(np.float16)
+        return images
 
     def __populate_array_with_cubes(
         self,
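
Both generators now accept `*args, **kwargs` and forward them to the superclass, which is how `workers` and `use_multiprocessing` reach the Keras base class once they can no longer be passed to `predict()` or `fit()`. The forwarding pattern itself can be shown without Keras installed. In the sketch below, `StandInDataset` is a deliberately simplified stand-in for the base class and `CubeBatches` is a hypothetical subclass; neither is code from cellfinder or Keras.

```python
class StandInDataset:
    """Simplified stand-in for the Keras base class (NOT keras.utils.Sequence)."""

    def __init__(self, workers: int = 1, use_multiprocessing: bool = False):
        self.workers = workers
        self.use_multiprocessing = use_multiprocessing


class CubeBatches(StandInDataset):
    """Hypothetical generator that forwards unknown arguments to its parent."""

    def __init__(self, points, batch_size: int = 2, *args, **kwargs):
        # Anything not named in this signature (e.g. `workers`) goes upward.
        super().__init__(*args, **kwargs)
        self.points = list(points)
        self.batch_size = batch_size

    def __len__(self) -> int:
        # Number of batches, rounding up so a partial last batch is kept.
        return -(-len(self.points) // self.batch_size)

    def __getitem__(self, index: int):
        start = index * self.batch_size
        return self.points[start : start + self.batch_size]


gen = CubeBatches(range(5), batch_size=2, workers=4, use_multiprocessing=False)
```

Calling `super().__init__` first, as the diff does, means the base class attributes exist before the subclass sets its own state.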

cellfinder/core/classify/resnet.py

+9 -6
@@ -1,9 +1,11 @@
 from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
 
-from tensorflow import Tensor
-from tensorflow.keras import Model
-from tensorflow.keras.initializers import Initializer
-from tensorflow.keras.layers import (
+from keras import (
+    KerasTensor as Tensor,
+)
+from keras import Model
+from keras.initializers import Initializer
+from keras.layers import (
     Activation,
     Add,
     BatchNormalization,
@@ -14,7 +16,7 @@
     MaxPooling3D,
     ZeroPadding3D,
 )
-from tensorflow.keras.optimizers import Adam, Optimizer
+from keras.optimizers import Adam, Optimizer
 
 #####################################################################
 # Define the types of ResNet
@@ -113,7 +115,7 @@ def non_residual_block(
     activation: str = "relu",
     use_bias: bool = False,
     bn_epsilon: float = 1e-5,
-    pooling_padding: str = "same",
+    pooling_padding: str = "valid",
     axis: int = 3,
 ) -> Tensor:
     """
@@ -131,6 +133,7 @@ def non_residual_block(
     )(x)
     x = BatchNormalization(axis=axis, epsilon=bn_epsilon, name="conv1_bn")(x)
     x = Activation(activation, name="conv1_activation")(x)
+
     x = MaxPooling3D(
         max_pool_size,
         strides=strides,
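
Changing `pooling_padding` from `"same"` to `"valid"` changes the spatial size coming out of the pooling layer. Under the usual Keras convention, `"valid"` pooling yields `floor((n - pool) / stride) + 1` elements along an axis, while `"same"` yields `ceil(n / stride)`. The sketch below just evaluates those two formulas; the helper `pooled_length` and the sizes in the usage note are illustrative and are not taken from the cellfinder model.

```python
import math


def pooled_length(n: int, pool_size: int, stride: int, padding: str) -> int:
    """Output length along one axis of a pooling layer.

    Illustrative helper using the standard "valid"/"same" size formulas.
    """
    if padding == "valid":
        if n < pool_size:
            return 0  # the window does not fit even once
        return (n - pool_size) // stride + 1
    if padding == "same":
        return math.ceil(n / stride)
    raise ValueError(f"unknown padding mode: {padding!r}")
```

With an axis of length 25, a pool size of 3 and a stride of 2, `"valid"` gives 12 outputs and `"same"` gives 13, so downstream layer shapes can differ by one element per axis.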
