[CodeCamp #15] Add sigmoid focal loss cpu impl #2536

Status: Open · wants to merge 4 commits into master

2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
@@ -49,7 +49,7 @@ We implement common ops used in detection, segmentation, etc.
| RoIAlign | √ | √ | √ | | |
| RoIAwarePool3d | | √ | √ | | |
| SAConv2d | | √ | | | |
| SigmoidFocalLoss | | √ | √ | | √ |
| SigmoidFocalLoss | √ | √ | √ | | √ |
| SoftmaxFocalLoss | | √ | | | √ |
| SoftNMS | | √ | | | |
| Sparse Convolution | | √ | | | |
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
@@ -49,7 +49,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| RoIAlign | √ | √ | √ | | |
| RoIAwarePool3d | | √ | √ | | |
| SAConv2d | | √ | | | |
| SigmoidFocalLoss | | √ | √ | | √ |
| SigmoidFocalLoss | √ | √ | √ | | √ |
| SoftmaxFocalLoss | | √ | | | √ |
| SoftNMS | | √ | | | |
| Sparse Convolution | | √ | | | |
92 changes: 92 additions & 0 deletions mmcv/ops/csrc/pytorch/cpu/sigmoid_focal_loss.cpp
@@ -0,0 +1,92 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <cfloat>
#include <cmath>

#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

template <typename T>
void sigmoid_focal_loss_forward_cpu_kernel(
    const int N, const T* input, const int64_t* target, const T* weight,
    T* output, const float gamma, const float alpha, const int num_classes) {
  for (int i = 0; i < N; i++) {
    T p = (T)1. / ((T)1. + exp(-input[i]));
    int64_t t = target[i / num_classes];
    if (t == (i % num_classes)) {  // positive class slot
      output[i] =
          -alpha * pow((T)1. - p, gamma) * log(std::max(p, (T)FLT_MIN));
    } else {  // negative class slot
      output[i] = -((T)1. - alpha) * pow(p, gamma) *
                  log(std::max((T)1. - p, (T)FLT_MIN));
    }
    if (weight != NULL) {
      output[i] *= weight[t];
    }
  }
}

template <typename T>
void sigmoid_focal_loss_backward_cpu_kernel(
    const int N, const T* input, const int64_t* target, const T* weight,
    T* grad_input, const float gamma, const float alpha,
    const int num_classes) {
  for (int i = 0; i < N; i++) {
    T p = (T)1. / ((T)1. + exp(-input[i]));
    int64_t t = target[i / num_classes];
    if (t == (i % num_classes)) {  // positive class slot
      grad_input[i] = -alpha * pow((T)1. - p, gamma) *
                      ((T)1. - p - (gamma * p * log(std::max(p, (T)FLT_MIN))));
    } else {  // negative class slot
      grad_input[i] =
          -((T)1. - alpha) * pow(p, gamma) *
          (gamma * ((T)1. - p) * log(std::max((T)1. - p, (T)FLT_MIN)) - p);
    }
    if (weight != NULL) {
      grad_input[i] *= weight[t];
    }
  }
}
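
// For reference, with p = 1 / (1 + exp(-x)) for a logit x, the two kernels
// above evaluate the binary focal loss for each (sample, class) slot and its
// gradient with respect to x (the std::max(..., FLT_MIN) calls only guard the
// logarithm against underflow):
//   positive slot:  FL     = -alpha * (1 - p)^gamma * log(p)
//                   dFL/dx = -alpha * (1 - p)^gamma * ((1 - p) - gamma * p * log(p))
//   negative slot:  FL     = -(1 - alpha) * p^gamma * log(1 - p)
//                   dFL/dx = -(1 - alpha) * p^gamma * (gamma * (1 - p) * log(1 - p) - p)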

void TensorSigmoidFocalLossForwardCPUKernelLauncher(Tensor input, Tensor target,
                                                    Tensor weight, Tensor output,
                                                    const float gamma,
                                                    const float alpha) {
  int output_size = output.numel();
  int num_classes = input.size(1);
  AT_ASSERTM(target.max().item<int64_t>() <= (int64_t)num_classes,
             "target label should be smaller than or equal to num_classes");
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "sigmoid_focal_loss_forward_cpu_kernel", [&] {
        sigmoid_focal_loss_forward_cpu_kernel<scalar_t>(
            output_size, input.data_ptr<scalar_t>(),
            target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
      });
}

void TensorSigmoidFocalLossBackwardCPUKernelLauncher(Tensor input, Tensor target,
                                                     Tensor weight,
                                                     Tensor grad_input,
                                                     const float gamma,
                                                     const float alpha) {
  int output_size = grad_input.numel();
  int num_classes = input.size(1);
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "sigmoid_focal_loss_backward_cpu_kernel", [&] {
        sigmoid_focal_loss_backward_cpu_kernel<scalar_t>(
            output_size, input.data_ptr<scalar_t>(),
            target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
            grad_input.data_ptr<scalar_t>(), gamma, alpha, num_classes);
      });
}

void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha);
void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight,
                                      Tensor grad_input, float gamma, float alpha);

REGISTER_DEVICE_IMPL(sigmoid_focal_loss_forward_impl, CPU,
                     TensorSigmoidFocalLossForwardCPUKernelLauncher);
REGISTER_DEVICE_IMPL(sigmoid_focal_loss_backward_impl, CPU,
                     TensorSigmoidFocalLossBackwardCPUKernelLauncher);
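
With the CPU kernels registered above, mmcv.ops.sigmoid_focal_loss should work on plain CPU tensors. A minimal usage sketch, mirroring the positional argument order used by the new tests below (shapes and values here are made up for illustration):

import torch
from mmcv.ops import sigmoid_focal_loss

pred = torch.randn(4, 3, requires_grad=True)  # logits: 4 samples, 3 classes, on CPU
target = torch.randint(0, 3, (4, ))  # per-sample class indices, dtype int64
loss = sigmoid_focal_loss(pred, target, 2.0, 0.25, None, 'mean')  # gamma, alpha, weight, reduction
loss.backward()  # exercises the backward CPU kernel as well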
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -5,6 +5,7 @@ onnxoptimizer; python_version < '3.10'
onnxruntime>=1.8.0
protobuf~=3.19.0
pytest
pytest-benchmark
PyTurboJPEG
scipy
tifffile
63 changes: 63 additions & 0 deletions tests/test_ops/test_focal_loss.py
@@ -13,6 +13,7 @@
_USING_PARROTS = False

# torch.set_printoptions(precision=8, threshold=100)
# Set to False to run the pytest-benchmark comparisons at the end of this file.
SKIP_BENCHMARK = True

inputs = [
([[1., 0], [0, 1.]], [0, 1]),
@@ -36,6 +37,12 @@
[0.07457211, 0.07457669, -0.02483728],
[-0.02462499, 0.08277918, 0.18050370]])]

benchmark_data = [{
    'name': f'{2 ** batch_exponent}_{2 ** num_classes_exponent}',
    'batch': 2**batch_exponent,
    'num_classes': 2**num_classes_exponent
} for batch_exponent in range(1, 14) for num_classes_exponent in range(1, 13)]


class Testfocalloss:

@@ -130,6 +137,7 @@ def test_softmax_half(self):
self._test_softmax(dtype=torch.half)

@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'npu',
marks=pytest.mark.skipif(
@@ -147,6 +155,7 @@ def test_sigmoid_float(self, device):
self._test_sigmoid(device=device, dtype=torch.float)

@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'npu',
marks=pytest.mark.skipif(
@@ -168,3 +177,57 @@ def test_grad_softmax_float(self):

def test_grad_sigmoid_float(self):
self._test_grad_sigmoid(dtype=torch.float)

    def _test_mmcv_cpu_sigmoid_focal_loss(self, args):
        from mmcv.ops import sigmoid_focal_loss
        loss = sigmoid_focal_loss(*args)
        loss.backward()

    def _test_torchvision_cpu_sigmoid_focal_loss(self, args):
        from torchvision.ops import sigmoid_focal_loss
        loss = sigmoid_focal_loss(*args)
        loss.backward()

    @pytest.mark.skipif(SKIP_BENCHMARK, reason='Skip benchmark.')
    @pytest.mark.parametrize(
        'param',
        benchmark_data,
        ids=[f"mmcv_{item['name']}" for item in benchmark_data])
    def test_mmcv_cpu_sigmoid_focal_loss_benchmark(self, param, benchmark):
        batch, num_classes = param['batch'], param['num_classes']
        device = 'cpu'
        dtype = torch.float
        alpha = 0.25
        gamma = 2.0
        np_x = np.random.rand(batch, num_classes)
        np_y = np.random.randint(0, num_classes, batch)

        x = torch.from_numpy(np_x).to(device).type(dtype)
        x.requires_grad_()
        y = torch.from_numpy(np_y).to(device).long()
        benchmark.pedantic(
            self._test_mmcv_cpu_sigmoid_focal_loss,
            ((x, y, gamma, alpha, None, 'mean'), ),
            rounds=10)

    @pytest.mark.skipif(SKIP_BENCHMARK, reason='Skip benchmark.')
    @pytest.mark.parametrize(
        'param',
        benchmark_data,
        ids=[f"torch_{item['name']}" for item in benchmark_data])
    def test_torch_cpu_sigmoid_focal_loss_benchmark(self, param, benchmark):
        batch, num_classes = param['batch'], param['num_classes']
        device = 'cpu'
        dtype = torch.float
        alpha = 0.25
        gamma = 2.0
        np_x = np.random.rand(batch, num_classes)
        # torchvision's sigmoid_focal_loss expects per-class binary targets.
        np_y = np.random.randint(0, 2, (batch, num_classes))

        x = torch.from_numpy(np_x).to(device).type(dtype)
        x.requires_grad_()
        y = torch.from_numpy(np_y).to(device).type(dtype)
        benchmark.pedantic(
            self._test_torchvision_cpu_sigmoid_focal_loss,
            # torchvision.ops.sigmoid_focal_loss takes (inputs, targets, alpha,
            # gamma, reduction), so alpha is passed before gamma.
            ((x, y, alpha, gamma, 'mean'), ),
            rounds=10)
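
Both benchmarks are skipped by default via SKIP_BENCHMARK above. With that flag set to False and pytest-benchmark installed (added to requirements/test.txt in this PR), an invocation along the lines of "pytest tests/test_ops/test_focal_loss.py -k benchmark" should report the mmcv and torchvision CPU timings side by side; the exact command may vary with the local test setup.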