Torch backend #326
Merged
giovannivolpe merged 8 commits (e4df470) into DeepTrackAI:torch-backend on May 21, 2025
1 of 25 checks passed
Conversation
giovannivolpe added a commit that referenced this pull request on May 28, 2025.
This is the proposal for enabling the use of a torch backend. The class / method names are not at all set in stone. In fact, I would very much appreciate feedback.
The core idea is as follows:
Internally, we use a special `xp` module instead of `numpy`. `xp` is based on `array-api-compat`, a Python library that provides a unified API for operating on arrays, following the `array-api` standard (which is what numpy implements, and is the most established standard). `xp` provides two core functionalities for us. First, as mentioned, it provides a unified interface for torch and numpy arrays (as well as cupy, dask, and jax). This means that as long as we use `xp` everywhere, we will be cross-compatible; we do not have to change the code path based on the array provider.

Second, I've set it up to automatically create arrays of the type defined by a `backend` option in the deeptrack config object. So, if you call `config.set_backend_numpy()`, then `xp.zeros((100, 100))` will create a numpy array, and if you call `config.set_backend_torch()`, the same code will create a torch tensor.
Feature.numpy(), Feature.torch()
There are two methods, `f.numpy()` and `f.torch()`, which can be used to override the global backend choice for a specific pipeline. This means that you can use different backends in different parts of the full pipeline. Note, however, that there is no automatic conversion between backends (passing a tensor to a numpy feature will give an error). This is a choice made for simplicity and performance.
Device
`xp` array creation can be given a `device` argument. For numpy, the only options are "cpu" and None, while for torch, all device options are available (either strings or torch.device objects). I propose that each feature has a device flag that is set similarly to the backend. We will have to ensure that any array creation calls pass this flag: `xp.zeros((100, 100), device=self.device)`. Again, I propose that we do not do automatic device transfers, and instead leave this to the user. However, we could provide a simple method to do so on the Feature class:
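The snippet that originally followed the colon is not preserved on this page. As a minimal sketch of what such a helper could look like (all class and method names here are hypothetical, not the final API):

```python
import numpy as np


class Feature:
    """Illustrative stand-in for deeptrack's Feature class."""

    def __init__(self, backend="numpy", device=None):
        self.backend = backend
        self.device = device  # e.g. "cpu", or "cuda:0" under torch

    def to(self, device):
        # Hypothetical device-transfer helper: it only updates the flag that
        # subsequent array creations read; existing arrays are not moved,
        # in line with the no-automatic-transfer policy above.
        self.device = device
        return self

    def _new_array(self, shape):
        if self.backend == "torch":
            import torch  # imported lazily, only when the torch backend is active
            return torch.zeros(shape, device=self.device)
        # Classic numpy has no general device support; only "cpu"/None are valid.
        return np.zeros(shape)


f = Feature().to("cpu")
a = f._new_array((2, 2))
```

The key design point is that `.to()` changes where future arrays are created rather than silently copying existing data between devices.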
Examples
This branch contains two examples in the root directory. These should be removed or moved to a better place before the PR is accepted. They demonstrate a minimal example of how this works, and how the torch backend could be used in a gradient-descent setting.
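Since the example notebooks are not reproduced here, the gradient-descent idea can be sketched as follows. To keep the sketch runnable without torch installed, it uses numpy with a hand-written analytic gradient; under the torch backend the gradient would instead come from autograd (e.g. `loss.backward()`):

```python
import numpy as np

# Toy fitting problem: recover a scalar parameter by gradient descent.
# With the torch backend, the loss would be computed through a deeptrack
# pipeline and the gradient obtained via autograd; here it is written by hand.
TARGET = 3.0


def loss(x):
    return (x - TARGET) ** 2


def grad(x):
    # Analytic derivative of the loss; autograd would supply this for torch.
    return 2.0 * (x - TARGET)


x = np.float64(0.0)
learning_rate = 0.1
for _ in range(100):
    x = x - learning_rate * grad(x)

print(round(float(x), 6))  # converges to 3.0
```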
Testing
A good starting point for testing this new backend would be to run the existing test suite with the torch backend.
I propose that each feature is additionally tested to ensure that its output array type matches the active backend.
I also propose that each feature is tested to ensure that gradients flow through it, including through all numeric properties. Full gradient compatibility is likely to take more time, so I propose that it not be required for the first release.
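As a sketch of what backend-parameterized tests could look like (class names are illustrative; the real suite would exercise deeptrack features rather than bare array creation), with the torch case only generated when torch is importable:

```python
import unittest

import numpy as np

# Map of backend name -> array module; torch is optional.
BACKENDS = {"numpy": np}
try:
    import torch
    BACKENDS["torch"] = torch
except ImportError:
    pass  # torch backend simply not tested when unavailable


class BackendTestBase:
    """Shared checks, run once per available backend via subclassing."""

    xp = None
    backend_name = None

    def test_output_array_type(self):
        a = self.xp.zeros((4, 4))
        # The created array should come from the active backend's module.
        self.assertEqual(type(a).__module__.split(".")[0], self.backend_name)


# Generate one concrete TestCase per available backend.
for name, module in BACKENDS.items():
    case = type(
        f"Test_{name}_backend",
        (BackendTestBase, unittest.TestCase),
        {"xp": module, "backend_name": name},
    )
    globals()[case.__name__] = case

if __name__ == "__main__":
    unittest.main()
```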
Users
Users are not required to use `xp` for their own features; it is only required for cross-compatibility between backends. As such, old features written using numpy will still work identically.
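To make the compatibility point concrete, here is a small sketch (with `xp` stood in by numpy, since that is the default backend) contrasting a cross-backend feature with an old-style numpy one:

```python
import numpy as np

# Stand-in for deeptrack's xp proxy; in this sketch it simply forwards to
# numpy, whereas the real proxy dispatches to the configured backend.
xp = np


def new_style_feature(image):
    # Written against xp: works unchanged whichever backend is active.
    return xp.clip(image, 0, 1)


def old_style_feature(image):
    # Plain numpy: still works identically, but only under the numpy backend.
    return np.clip(image, 0, 1)


img = np.array([-0.5, 0.5, 1.5])
new = new_style_feature(img)  # both clamp values into [0, 1]
old = old_style_feature(img)
```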