Torch backend #326

Merged
giovannivolpe merged 8 commits into
DeepTrackAI:torch-backend from
BenjaminMidtvedt:torch-backend
May 21, 2025

Conversation


@BenjaminMidtvedt BenjaminMidtvedt commented May 4, 2025

This is a proposal for enabling the use of a torch backend. The class/method names are not at all set in stone; in fact, I would very much appreciate feedback.

The core idea is as follows:

Internally, we use a special xp module instead of numpy. xp is based on array-api-compat, a Python library that provides a unified API for operating on arrays, based on the array API standard (which numpy follows, and which is the most established standard).

xp provides two core functionalities for us. First, as mentioned, it provides a unified interface for torch and numpy arrays (as well as cupy, dask, and jax). This means that as long as we use xp everywhere, we will be cross-compatible: we do not have to change the code path based on the array provider.

Second, I've set it up to automatically create arrays of the type defined by the backend option in the deeptrack config object. So, if you call config.set_backend_numpy(), then xp.zeros((100, 100)) will create a numpy array, and if you call config.set_backend_torch(), the same code will create a torch tensor.
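To make the dispatch idea concrete, here is a minimal, self-contained sketch of the mechanism: a proxy object forwards attribute lookups to whichever array module the config currently selects. All names here (_Config, _XP) are hypothetical stand-ins; the real implementation wraps array-api-compat rather than the raw modules.

```python
import numpy as np

class _Config:
    """Hypothetical stand-in for the deeptrack config object."""
    def __init__(self):
        self._backend = np  # numpy is the default backend

    def set_backend_numpy(self):
        self._backend = np

    def set_backend_torch(self):
        import torch  # imported lazily so torch stays optional
        self._backend = torch

config = _Config()

class _XP:
    """Forwards attribute access to the currently configured backend module."""
    def __getattr__(self, name):
        return getattr(config._backend, name)

xp = _XP()

# The same xp call site produces different array types per backend.
config.set_backend_numpy()
a = xp.zeros((100, 100))
print(type(a))  # <class 'numpy.ndarray'>
```

With config.set_backend_torch() active, the identical xp.zeros call would return a torch.Tensor instead.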

Feature.numpy(), Feature.torch()

There are two methods, f.numpy() and f.torch(), which can be used to override the global backend choice for a specific pipeline. This means that you can use different backends in different parts of the full pipeline. Note, however, that there is no automatic conversion between backends (passing a tensor to a numpy feature will raise an error). This is a choice made for simplicity and performance.
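A rough sketch of how such a per-feature override could take precedence over the global setting (the Feature class and GLOBAL_BACKEND here are simplified, hypothetical stand-ins for the proposal, not the actual implementation):

```python
GLOBAL_BACKEND = "numpy"  # stand-in for the global config setting

class Feature:
    def __init__(self):
        self._backend = None  # None means "follow the global config"

    def numpy(self):
        self._backend = "numpy"
        return self  # return self so calls can be chained

    def torch(self):
        self._backend = "torch"
        return self

    @property
    def backend(self):
        # per-feature override wins over the global config
        return self._backend or GLOBAL_BACKEND

f = Feature()
print(f.backend)          # "numpy" (follows the global config)
print(f.torch().backend)  # "torch" (overridden for this feature)
```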

Device

xp array creation can be given a device argument. For numpy, the only options are "cpu" and None, while for torch, all device options are available (either strings or torch.device objects). I propose that each feature has a device flag that is set similarly to the backend. We will have to ensure that all array creation calls pass this flag: xp.zeros((100, 100), device=self.device).

Again, I propose that we do not perform automatic device transfers, and instead leave this to the user. However, we could provide a simple method for it on the Feature class:

foo = foo_pipeline.torch().cuda()
bar = bar_pipeline.numpy().cpu()
foobar = foo.output(device="cpu", backend="numpy") >> bar
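The boundary conversion that output(device=..., backend=...) would perform might look roughly like the following. This is a hypothetical sketch of the conversion step only; it handles just the numpy case so it runs without torch installed, and a real version would add a torch branch calling torch.as_tensor(...).to(device).

```python
import numpy as np

def to_backend(array, backend="numpy", device="cpu"):
    """Explicitly convert an array to the requested backend and device."""
    if backend == "numpy":
        if device not in ("cpu", None):
            # numpy arrays only live on the CPU
            raise ValueError(f"numpy does not support device {device!r}")
        return np.asarray(array)
    raise NotImplementedError(f"unknown backend: {backend!r}")

arr = to_backend([1.0, 2.0], backend="numpy", device="cpu")
print(type(arr).__name__)  # ndarray
```

Keeping this step explicit means a pipeline never silently pays for a host-device copy: the user sees exactly where the transfer happens.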

Examples

This branch contains two examples in the repository root. These should be removed or moved to a better place before the PR is accepted. They demonstrate a minimal example of how this works, and how the torch backend could be used in a gradient-descent setting.

Testing

A good starting point for testing this new backend would be to run the existing test suite with the torch backend.
I propose that each feature is additionally tested to ensure the output array type matches the configured backend.
I also propose that each feature is additionally tested to ensure that gradients flow through the output. This should include all numeric properties as well. Full gradient compatibility is likely to take more time, and I propose that it not be required for the first release.
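The output-type check could be parameterized over backends roughly as below. This is a sketch under assumed names: run_feature is a hypothetical stand-in for resolving a real feature, and only the numpy entry is exercised so the sketch runs without torch installed (a torch entry would map to torch.Tensor).

```python
import unittest
import numpy as np

# expected output type per backend; a torch entry would map to torch.Tensor
EXPECTED_TYPES = {"numpy": np.ndarray}

def run_feature(backend):
    # hypothetical stand-in for resolving a real feature under `backend`
    return np.zeros((4, 4))

class TestBackendOutputType(unittest.TestCase):
    def test_output_type_matches_backend(self):
        for backend, expected in EXPECTED_TYPES.items():
            with self.subTest(backend=backend):
                self.assertIsInstance(run_feature(backend), expected)

unittest.main(argv=["backend_tests"], exit=False)
```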

Users

Users are not required to use xp for their own features; it is only needed for cross-compatibility between backends. As such, old features written using numpy will still work identically.
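For instance, a user feature written directly against numpy keeps working under the numpy backend, since xp arrays are then plain numpy arrays. The function name below is a made-up example, not part of the API:

```python
import numpy as np

def user_feature_get(image):
    # plain numpy code, untouched by the backend machinery
    return np.clip(image * 2.0, 0.0, 1.0)

out = user_feature_get(np.full((2, 2), 0.4))
print(out)  # a 2x2 array of 0.8
```

Such a feature simply won't be usable from a torch pipeline until it is ported to xp.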

@BenjaminMidtvedt BenjaminMidtvedt marked this pull request as draft May 4, 2025 14:30
@giovannivolpe giovannivolpe changed the base branch from develop to torch-backend May 21, 2025 12:46
@giovannivolpe giovannivolpe marked this pull request as ready for review May 21, 2025 12:47
@giovannivolpe giovannivolpe merged commit e4df470 into DeepTrackAI:torch-backend May 21, 2025
1 of 25 checks passed
giovannivolpe added a commit that referenced this pull request May 28, 2025
…ce in features

* Torch backend (#326)

* support variable computation backend

* standardize random array creation

* add some examples

* add BackendContext to manage backend settings in Config class

* export xp from backend

* add torch() and numpy() methods

* refactor Feature class to use backend context for execution

* store backend on initialization in Feature class

* u

* Update requirements.txt

* Remove CuPy from all code

* remove cupy top level

* remove deprecated enable/disable gpu

* Refactor Config class to remove cupy backend support

* remove functionality to handle cupy arrays

* Remove cupy array handling from elementwise tests

* Refactor Image class to remove CuPy support and clean up related methods.

* Remove unused CuPy import from optics.py

* Remove CuPy array handling from harmonics function in mie.py

* Some code cleanup and comments to make it understandable (#337)

Co-authored-by: Giovanni Volpe <46021832+giovannivolpe@users.noreply.github.com>

* Bm/import-torch-lazy (#335)

* Only import torch on use

* formatting

* Remove unused import of torch in _config.py

---------

Co-authored-by: Giovanni Volpe <46021832+giovannivolpe@users.noreply.github.com>

* Update 01. Using xp.ipynb

* Update _config.py

* Update 02. Gradients.ipynb

* Add torch array function dispatching (#338)

* Update 01. Using xp.ipynb

* small edits

* Update _config.py

* Update _config.py

* Update _config.py

* Update _config.py

* Update DTDV411_style.ipynb

* Update DTDV411_style.ipynb

* Update DTDV401_overview.ipynb

* Update DTDV401_overview.ipynb

* Bm/implement-device-dtype (#340)

* get dtype from xp module

* implement dtypes and device on Feature

* import xp module

* Refactor backend initialization in _Proxy class to use set_backend method and update backend_info retrieval

* Update import statement in features.py to include Literal type

* Bm/implement-math (#341)

* update types to allow torch arrays

* implement xp math v1

* get dtype from xp module

* implement dtypes and device on Feature

* import xp module

* implement math with xp 2

* Add VSCode settings for Python unittest configuration

* Refactor backend initialization in _Proxy class to use set_backend method and update backend_info retrieval

* Update import statement in features.py to include Literal type

* Update type alias for ArrayLike to use string reference for Image

* Update VSCode settings to simplify unittestArgs configuration

* remove lazy imports pointing to nothing

* Refactor Average feature calculation to stack images before computing mean

* update imports to use non-relative paths

* Add BackendTestBase

* Remove unused backend mixin classes from BackendTestBase

* Refactor math tests to use backend-specific implementations and extend test coverage for numpy and torch backends

* Update test_math.py

* Update __init__.py

* Delete settings.json

* Update DTDV401_overview.ipynb

* Update test_math.py

* u

* Update _config.py

* Update _config.py

* Update _config.py

* Update _config.py

* Create test__config.py

* Update test__config.py

* Update test__config.py

* Update test__config.py

* Update _config.py

* Update test__config.py

* Update test__config.py

* Update DTDV411_style.ipynb

* Update _config.py

* Update _config.py

* Update test__config.py

* Update test__config.py

* Update _config.py

* Update test__config.py

* Update test__config.py

* Update _config.py

* u

* Update README.md

* Update DTDV421_backends.ipynb

* Update DTAT399F_backend._config.ipynb

* Update DTAT399F_backend._config.ipynb

* Update __init__.py

* Update DTAT399F_backend._config.ipynb

* Update types.py

* Delete 02. Gradients.ipynb

* Create DTGS161_torch_fitting.ipynb

* Update README.md

* u

* Update README.md

* Update README.md

* Update features.py

* Update DTGS161_torch_fitting.ipynb

---------

Co-authored-by: BenjaminMidtvedt <41636530+BenjaminMidtvedt@users.noreply.github.com>
