Extend autograd support#30

Draft
askorikov wants to merge 5 commits into ahendriksen:master from askorikov:extend_autograd_support
Conversation

@askorikov
Contributor

This PR adds support for forward-mode and higher-order differentiation in the PyTorch autograd engine. Additionally, it enables integration with torch.func, which implements JAX-style composable functional transforms in PyTorch, allowing flexible and elegant use of different types of automatic differentiation (vectorizing map functionality is also implemented in this PR).
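As an illustration of what these torch.func transforms enable (a sketch using a toy elementwise function, not the library's actual operators), forward-mode, higher-order, and batched differentiation can all be composed on the same function:

```python
import torch
from torch.func import jacfwd, jacrev, vmap

# Toy elementwise function standing in for a differentiable operator;
# the PR itself targets the library's own operators, not this function.
def f(x):
    return torch.sin(x) * x

x = torch.randn(3)

# Forward-mode Jacobian
J = jacfwd(f)(x)                              # (3, 3), diagonal since f is elementwise

# Higher-order differentiation: Hessian by composing transforms
H = jacfwd(jacrev(lambda v: f(v).sum()))(x)   # (3, 3)

# Vectorizing map over a leading batch dimension
batch = torch.randn(10, 3)
y = vmap(f)(batch)                            # (10, 3)
```

Each transform returns an ordinary function, so they compose freely; this is the style of use the PR aims to support for the library's operators.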

Integration with torch.func, however, breaks compatibility with PyTorch < 2.0, so we need to assess whether it is safe to drop support for older versions at this point.

To do:

  • Check for use cases requiring PyTorch < 2.0
  • Add tests

With the ambition to enable forward-mode and higher-order differentiation support, checking `input.requires_grad` is not sufficient to determine whether the relevant arguments will be needed later.
WARNING: this breaks compatibility with PyTorch < 2.0
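A sketch of the pattern this implies for a custom `torch.autograd.Function` (a hypothetical `Square` op, not the PR's actual code): in the PyTorch >= 2.0 style that torch.func requires, `forward()` is split from `setup_context()`, and tensors are saved unconditionally rather than gated on `input.requires_grad`:

```python
import torch

# Hypothetical example op written in the PyTorch >= 2.0 style needed for
# torch.func support: forward() takes no ctx, and saving happens in
# setup_context() without checking input.requires_grad, since forward-mode
# or higher-order transforms may need the tensors regardless.
class Square(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        # Save for both reverse and forward mode: requires_grad alone
        # cannot tell us whether a later transform will need these.
        ctx.save_for_backward(x)
        ctx.save_for_forward(x)

    @staticmethod
    def backward(ctx, grad_out):
        # Reverse-mode rule: d(x^2)/dx = 2x
        x, = ctx.saved_tensors
        return 2 * x * grad_out

    @staticmethod
    def jvp(ctx, x_tangent):
        # Forward-mode rule: d(x^2) = 2 x dx
        x, = ctx.saved_tensors
        return 2 * x * x_tangent
```

Reverse mode works via `.backward()` as before, and defining `jvp` makes the same op usable with `torch.autograd.forward_ad` and the torch.func transforms.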
