
Commit 5b45bfa

Merge pull request #31 from teddykoker/half
Add fp16 support
2 parents: 27b4ce5 + 82a9759

4 files changed: 27 additions & 11 deletions
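For context: this commit enables torchsort's soft_sort and soft_rank on half-precision inputs. A minimal usage sketch, assuming torchsort >= 0.1.7 built with CUDA support (the shapes here are arbitrary):

    import torch
    import torchsort

    x = torch.randn(8, 128, requires_grad=True).cuda().half()  # fp16 input on GPU
    y = torchsort.soft_sort(x, regularization="l2", regularization_strength=1.0)
    r = torchsort.soft_rank(x)
    y.sum().backward()  # the backward pass dispatches to the fp16 kernels as well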


setup.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def ext_modules():
 
 setup(
     name="torchsort",
-    version="0.1.6",
+    version="0.1.7",
     description="Differentiable sorting and ranking in PyTorch",
     author="Teddy Koker",
     url="https://github.com/teddykoker/torchsort",

tests/test_ops.py

Lines changed: 18 additions & 2 deletions
@@ -15,8 +15,8 @@
 REGULARIZATION = ["l2", "kl"]
 REGULARIZATION_STRENGTH = [1e-1, 1e0, 1e1]
 
-DEVICES = (
-    [torch.device("cpu")] + ([torch.device("cuda")] if torch.cuda.is_available() else [])
+DEVICES = [torch.device("cpu")] + (
+    [torch.device("cuda")] if torch.cuda.is_available() else []
 )
 
 torch.manual_seed(0)

@@ -59,3 +59,19 @@ def test_vs_original(funcs, regularization, regularization_strength, device):
         funcs[0](x, **kwargs).cpu(),
         funcs[1](x.cpu(), **kwargs),
     )
+
+
+@pytest.mark.parametrize("function", [soft_rank, soft_sort])
+@pytest.mark.parametrize("regularization", REGULARIZATION)
+@pytest.mark.parametrize("regularization_strength", REGULARIZATION_STRENGTH)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA to test fp16")
+def test_half(function, regularization, regularization_strength, device):
+    x = torch.randn(BATCH_SIZE, SEQ_LEN, requires_grad=True).cuda().half()
+    f = partial(
+        function,
+        regularization=regularization,
+        regularization_strength=regularization_strength,
+    )
+    # don't think there's a better way of testing; tolerance must be pretty high
+    assert torch.allclose(f(x), f(x.float()).half(), atol=1e-1)
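A note on the loose tolerance above: fp16 carries only about three decimal digits of precision, and the difference against the fp32 reference compounds through the sequential isotonic-regression solve, hence atol=1e-1. The machine epsilon is easy to check:

    import torch
    print(torch.finfo(torch.float16).eps)  # 0.0009765625, i.e. 2**-10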

torchsort/isotonic_cpu.cpp

Lines changed: 4 additions & 4 deletions
@@ -288,7 +288,7 @@ torch::Tensor isotonic_l2(torch::Tensor y) {
     auto target = torch::zeros_like(y);
     auto c = torch::zeros_like(y);
 
-    AT_DISPATCH_FLOATING_TYPES(y.scalar_type(), "isotonic_l2", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(y.scalar_type(), "isotonic_l2", ([&] {
         isotonic_l2_kernel<scalar_t>(
             y.accessor<scalar_t, 2>(),
             sol.accessor<scalar_t, 2>(),

@@ -311,7 +311,7 @@ torch::Tensor isotonic_kl(torch::Tensor y, torch::Tensor w) {
     auto lse_w_ = torch::zeros_like(y);
     auto target = torch::zeros_like(y);
 
-    AT_DISPATCH_FLOATING_TYPES(y.scalar_type(), "isotonic_kl", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(y.scalar_type(), "isotonic_kl", ([&] {
         isotonic_kl_kernel<scalar_t>(
             y.accessor<scalar_t, 2>(),
             w.accessor<scalar_t, 2>(),

@@ -330,7 +330,7 @@ torch::Tensor isotonic_l2_backward(torch::Tensor s, torch::Tensor sol, torch::Te
     auto n = sol.size(1);
     auto ret = torch::zeros_like(sol);
 
-    AT_DISPATCH_FLOATING_TYPES(sol.scalar_type(), "isotonic_l2_backward", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(sol.scalar_type(), "isotonic_l2_backward", ([&] {
         isotonic_l2_backward_kernel<scalar_t>(
             s.accessor<scalar_t, 2>(),
             sol.accessor<scalar_t, 2>(),

@@ -347,7 +347,7 @@ torch::Tensor isotonic_kl_backward(torch::Tensor s, torch::Tensor sol, torch::Te
     auto n = sol.size(1);
     auto ret = torch::zeros_like(sol);
 
-    AT_DISPATCH_FLOATING_TYPES(sol.scalar_type(), "isotonic_kl_backward", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(sol.scalar_type(), "isotonic_kl_backward", ([&] {
         isotonic_kl_backward_kernel<scalar_t>(
             s.accessor<scalar_t, 2>(),
             sol.accessor<scalar_t, 2>(),
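All four hunks follow one pattern: AT_DISPATCH_FLOATING_TYPES instantiates the kernel templates only for float and double, so a half tensor previously fell through to the macro's error branch; AT_DISPATCH_FLOATING_TYPES_AND_HALF adds an at::Half instantiation. Roughly what a caller saw before the change (a sketch; the exact wording comes from PyTorch's dispatch macro, so treat it as approximate):

    import torch
    import torchsort

    x = torch.randn(4, 16, device="cuda").half()
    torchsort.soft_sort(x)
    # before this commit: RuntimeError: "isotonic_l2" not implemented for 'Half'
    # after this commit:  returns a torch.float16 tensor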

torchsort/isotonic_cuda.cu

Lines changed: 4 additions & 4 deletions
@@ -316,7 +316,7 @@ torch::Tensor isotonic_l2(torch::Tensor y) {
     const int threads = 1024;
     const int blocks = (batch + threads - 1) / threads;
 
-    AT_DISPATCH_FLOATING_TYPES(y.scalar_type(), "isotonic_l2", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(y.scalar_type(), "isotonic_l2", ([&] {
         isotonic_l2_kernel<scalar_t><<<blocks, threads>>>(
             y.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
             sol.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),

@@ -342,7 +342,7 @@ torch::Tensor isotonic_kl(torch::Tensor y, torch::Tensor w) {
     const int threads = 1024;
     const int blocks = (batch + threads - 1) / threads;
 
-    AT_DISPATCH_FLOATING_TYPES(y.scalar_type(), "isotonic_kl", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(y.scalar_type(), "isotonic_kl", ([&] {
         isotonic_kl_kernel<scalar_t><<<blocks, threads>>>(
             y.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
             w.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),

@@ -365,7 +365,7 @@ torch::Tensor isotonic_l2_backward(torch::Tensor s, torch::Tensor sol, torch::Te
     const int threads = 1024;
     const int blocks = (batch + threads - 1) / threads;
 
-    AT_DISPATCH_FLOATING_TYPES(sol.scalar_type(), "isotonic_l2_backward", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(sol.scalar_type(), "isotonic_l2_backward", ([&] {
         isotonic_l2_backward_kernel<scalar_t><<<blocks, threads>>>(
             s.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
             sol.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),

@@ -387,7 @@ -387,7 +387,7 @@ torch::Tensor isotonic_kl_backward(torch::Tensor s, torch::Tensor sol, torch::Te
     const int threads = 1024;
     const int blocks = (batch + threads - 1) / threads;
 
-    AT_DISPATCH_FLOATING_TYPES(sol.scalar_type(), "isotonic_kl_backward", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(sol.scalar_type(), "isotonic_kl_backward", ([&] {
         isotonic_kl_backward_kernel<scalar_t><<<blocks, threads>>>(
             s.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
             sol.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
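Unrelated to the fp16 change but visible in the context lines: each CUDA thread processes one row of the batch, so the grid size is a ceiling division over the batch dimension. A quick sanity check of the arithmetic (the batch size here is made up):

    batch, threads = 4096, 1024
    blocks = (batch + threads - 1) // threads  # ceiling division -> 4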
