Fix broadcasts which are type unstable with Dual numbers #1441
DomCRose wants to merge 7 commits into FluxML:master
I'm traveling right now, but I should be able to look at this in a couple of days. I think all the changes make sense, and I'm sorry for missing this in the original PR.
Great, thanks. No rush.
Certainly. I don't believe this breaks anything. I reverted my attempt to merge the functions further when I realized it broke GPU compilation (in a real -> real case, actually, not even a complex one).
Closes #1439.
This is achieved by moving a dispatch from the element type of the broadcast output to each individual element within the pullback, and by ensuring that arrays with a non-concrete element type take the same path as concrete Dual arrays. The per-element dispatch should hopefully compile away when the eltype is concrete, and indeed some simple benchmarks show no loss of performance, but I've hardly been exhaustive.
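To illustrate the idea, here is a minimal sketch and not the code in this PR. `MiniDual`, `relu_like`, and `partial_of` are hypothetical names standing in for `ForwardDiff.Dual`, a user function, and the per-element extraction step. It assumes a function that is type stable on `Float64` but returns a plain number from one branch when given a dual.

```julia
# Hypothetical stand-in for a dual number: a value plus one partial derivative.
struct MiniDual
    value::Float64
    partial::Float64
end

# Per-element dispatch: a dual carries its partial, while a plain number
# (e.g. produced by a constant branch) contributes a zero derivative.
partial_of(d::MiniDual) = d.partial
partial_of(::Real) = 0.0

# Type stable on Float64 inputs, but on duals the `else` branch
# returns a Float64 instead of a MiniDual.
relu_like(d::MiniDual) = d.value > 0 ? d : 0.0

xs = [MiniDual(1.0, 1.0), MiniDual(-2.0, 1.0)]
out = relu_like.(xs)            # mixed elements, so eltype(out) is not concrete

# A check on eltype(out) alone cannot tell that some entries are duals;
# dispatching on each element handles both kinds of entry.
grads = partial_of.(out)
```

When every element has the same concrete type, the call to `partial_of` resolves statically, which is consistent with the expectation above that the per-element dispatch costs nothing in the type-stable case.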
I've also added a few tests covering various real / complex input / output combinations, and a specific case that produced errors before rather than silently failing.
While it's nice that this works, it could be worth adding a note to the documentation about the performance of broadcasts that have a type-stable forward pass but become type unstable on Dual inputs, and perhaps also noting that such Dual-input stability is required for the code to work on the GPU. But I'm not sure where that could go.
Edit: just to note, I did also try merging the complex and real input branches into one function that dispatched according to the argument type. While it worked well on CPU, this seemed to stump the GPU compiler in some cases for reasons I don't understand.
PR Checklist