Conversation
…ment_linalg_solve_from_NumPy
|
Thank you for the PR! |
|
Thank you for the PR! |
|
Thank you for the PR! |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #2113 +/- ##
==========================================
+ Coverage 91.63% 91.67% +0.04%
==========================================
Files 86 86
Lines 14005 14054 +49
==========================================
+ Hits 12833 12884 +51
+ Misses 1172 1170 -2
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
Thank you for the PR! |
ClaudiaComito
left a comment
There was a problem hiding this comment.
Thanks a lot for addressing this @brownbaerchen . I have a few comments, see below. In general I'd rather avoid copy-pasting documentation from numpy or torch, although obviously they will be very similar. If we are copy-pasting, we should point out where from.
Also tagging @GioPede who requested the feature and might want to check it out as well.
heat/core/linalg/solver.py
Outdated
| # raise error if b is distributed in disallowed way | ||
| if b.is_distributed() and b.split == b_non_batched_axis: | ||
| raise ValueError( | ||
| f"b of shape {b.shape} with A of shape {A.shape} is split in {b.split} but may not be distributed in non-batched axis {b_non_batched_axis}" |
There was a problem hiding this comment.
This might be a bit obscure. How about "A and b are not distributed along the same batch axis: A.shape, A.split, b.shape, b.split...."
heat/core/linalg/solver.py
Outdated
| raise ValueError(f"Split of A and b must match, but got {A.split} and {b.split}") | ||
|
|
||
| # figure out what the output vector looks like | ||
| out_initalization = {"dtype": b.dtype, "device": b.device} |
There was a problem hiding this comment.
dtype should probably be types.promote_types(a.dtype, b.dtype)
|
Thank you for the PR! |
|
Thank you for the PR! |
…linalg_solve_from_NumPy
|
@ClaudiaComito you need to reapprove after I resolved the merge conflict from moving the tests. |
Due Diligence
Description
This function directly solves linear systems$Ax=b$ for $x$ . It supports batched solving and $A$ and $b$ may only be distributed in non-batched axes.
The function simply passes the local data to pytorch and makes sure the input and output make sense for heat.
Note that pytorch allows to pass$xA=b$ instead. I didn't implement this here because I wasn't sure if it's needed and didn't want to bother. If you want to see this here, I suggest splitting #2036 into sub issues and not doing this as part of this PR. @GioPede, please let us know if you need this. As
left=Falseintorch.linalg.solvein order to solvenumpy.linalg.solvedoes not have this option, I assumed you don't.I basically copy pasted the docstring from pytorch. I have a question @ClaudiaComito: I left some stuff in that is relevant not for the heat code, but only for the pytorch code that is called underneath. Specifically, I am thinking about the docstring including limitation of datatypes and raising
RuntimeErrorif the matrix is not invertible. I don't like this because if it changes in pytorch, our documentation is just lying. On the other hand, users don't want to have to look up the torch documentation for the heat function they are using. What is the policy for this in heat?Issue/s resolved: #2036
Changes proposed:
Type of change
Does this change modify the behaviour of other functions? If so, which?
no