Skip to content

Commit 9df7226

Browse files
authored
Merge pull request #1493 from lkdvos/ld/sort
2 parents c1d82be + 03a8ef7 commit 9df7226

File tree

2 files changed

+5
-15
lines changed

2 files changed

+5
-15
lines changed

src/lib/array.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -251,20 +251,6 @@ end
251251

252252
@adjoint iterate(r::UnitRange, i...) = iterate(r, i...), _ -> nothing
253253

254-
@adjoint function sort(x::AbstractArray; by=identity)
255-
p = sortperm(x, by=by)
256-
return x[p], x̄ -> (x̄[invperm(p)],)
257-
end
258-
259-
@adjoint function filter(f, x::AbstractVector)
260-
t = map(f, x)
261-
x[t], Δ -> begin
262-
dx = _zero(x, eltype(Δ))
263-
dx[t] .= Δ
264-
(nothing, dx)
265-
end
266-
end
267-
268254
# Iterators
269255

270256
@adjoint function enumerate(xs)

test/gradcheck.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,13 +425,17 @@ end
425425
[2,3,1],
426426
[1, 2, 3],
427427
[1,2,3],
428-
[2,1,3]
428+
[2,1,3],
429+
[1,3,2],
430+
[3,2,1]
429431
]
430432
for i = 1:3
431433
@test gradient(v->sort(v)[i], [3.,1,2])[1][correct[1][i]] == 1
432434
@test gradient(v->sort(v)[i], [1.,2,3])[1][correct[2][i]] == 1
433435
@test gradient(v->sort(v,by=x->x%10)[i], [11,2,99])[1][correct[3][i]] == 1
434436
@test gradient(v->sort(v,by=x->x%10)[i], [2,11,99])[1][correct[4][i]] == 1
437+
@test gradient(v->sort(v,rev=true)[i], [3.,1,2])[1][correct[5][i]] == 1
438+
@test gradient(v->sort(v,rev=true)[i], [1.,2,3])[1][correct[6][i]] == 1
435439
end
436440
end
437441

0 commit comments

Comments
 (0)