Skip to content

Commit 1696cee

Browse files
committed
Prevent type=inferability escaping for rrule of sortslices
1 parent e055009 commit 1696cee

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/rulesets/Base/sort.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,13 @@ end
6060

6161
function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...)
6262
p = sortperm(collect(eachslice(x; dims=dims)); kw...)
63-
inds = ntuple(d -> d == dims ? p : (:), ndims(x))
6463
function sortslices_pullback(dy)
65-
return (NoTangent(), ∇getindex(x, unthunk(dy), inds...))
64+
# avoid closing over `inds` as it doesn't fully infer and that makes it worse
65+
# recomputing is cheap
66+
inds_inner = ntuple(d -> d == dims ? p : (:), ndims(x))
67+
return (NoTangent(), ∇getindex(x, unthunk(dy), inds_inner...))
6668
end
69+
inds = ntuple(d -> d == dims ? p : (:), ndims(x))
6770
return x[inds...], sortslices_pullback
6871
end
6972

test/rulesets/Base/sort.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2))
2828
test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last))
29-
test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum), check_inferred=false)
29+
test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum))
3030

3131
@test_throws Exception sortslices(Diagonal(1:3), dims=1)
3232
end

0 commit comments

Comments
 (0)