
Commit eafcede

Merge pull request #9 from DaGaiBa/rc4main
Fix precision error of FocalLoss op
2 parents 2d6978e + ce612f7 commit eafcede
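
The fix applies the same pattern in all four NPU kernels: when the incoming tensors are float16 (at::kHalf), cast them to float32 before building the OpCommand, run the NPU operator in float32, then cast the result back to float16 and copy it into the caller's output tensor. A minimal sketch of that cast-and-restore pattern follows; the helper name run_op_in_fp32 and its callback signature are illustrative assumptions, not code from this commit.

// Illustrative sketch only: isolates the fp16 -> fp32 -> fp16 pattern that
// this commit repeats inside each focal-loss kernel. The helper name and the
// std::function callback are assumptions, not code from the repository.
#include <ATen/ATen.h>
#include <functional>

static void run_op_in_fp32(
    const at::Tensor &input, at::Tensor &output,
    const std::function<void(const at::Tensor &, at::Tensor &)> &run_op) {
  const bool is_half = input.scalar_type() == at::kHalf;
  // Promote float16 tensors to float32 so the operator computes its
  // intermediate values at full precision.
  at::Tensor input_y = is_half ? input.to(at::kFloat) : input;
  at::Tensor output_y = is_half ? output.to(at::kFloat) : output;
  run_op(input_y, output_y);
  if (is_half) {
    output_y = output_y.to(at::kHalf);  // restore the caller's dtype
  }
  output.copy_(output_y);  // write the result back into the caller's tensor
}

Computing in float32 and down-casting only the final result keeps the intermediate focal-loss values out of float16, which is the usual source of this kind of precision error.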

File tree

1 file changed, +66 -12 lines changed


mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp

@@ -4,6 +4,13 @@ using namespace std;
 
 void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor output_y = output;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    output_y = output.to(at::kFloat);
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::ones_like(input);
   if (n_class == 1) {
@@ -15,21 +22,28 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
   }
   target_y = target_y.to(at::kInt);
   int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
   if (weight_size > 0) {
     weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
   }
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SigmoidFocalLoss")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(weight_y)
-      .Output(output)
+      .Output(output_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    output_y = output_y.to(at::kHalf);
+  }
+  output.copy_(output_y);
 }
 
 void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
@@ -38,6 +52,13 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
 void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                      Tensor grad_input, float gamma,
                                      float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor grad_input_y = grad_input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    grad_input_y = grad_input.to(at::kFloat);
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::ones_like(input);
   if (n_class == 1) {
@@ -50,22 +71,29 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
   target_y = target_y.to(at::kInt);
   at::Tensor grad_up = at::ones_like(input);
   int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
   if (weight_size > 0) {
     weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
   }
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SigmoidFocalLossGrad")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(grad_up)
       .Input(weight_y)
-      .Output(grad_input)
+      .Output(grad_input_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    grad_input_y = grad_input_y.to(at::kHalf);
+  }
+  grad_input.copy_(grad_input_y);
 }
 
 void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
@@ -74,26 +102,38 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
 
 void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
+  at::Tensor input_y = input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::one_hot(target, n_class);
   target_y = target_y.to(at::kInt);
   int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
   if (weight_size > 0) {
     weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
   }
-  at::Tensor op_output = at::ones_like(input);
+
+  at::Tensor op_output = at::ones_like(input_y);
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SoftmaxFocalLoss")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(weight_y)
       .Output(op_output)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    op_output = op_output.to(at::kHalf);
+  }
   int64_t n_batch = input.size(0);
   c10::SmallVector<int64_t, 2> offsets = {0, 0};
   c10::SmallVector<int64_t, 2> sizes = {n_batch, 1};
@@ -124,27 +164,41 @@ void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
 void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                      Tensor buff, Tensor grad_input,
                                      float gamma, float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor grad_input_y = grad_input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    grad_input_y = grad_input.to(at::kFloat);
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::one_hot(target, n_class);
   target_y = target_y.to(at::kInt);
   at::Tensor grad_up = at::ones_like(input);
   int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
   if (weight_size > 0) {
     weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
   }
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SoftmaxFocalLossGrad")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(grad_up)
       .Input(weight_y)
-      .Output(grad_input)
+      .Output(grad_input_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    grad_input_y = grad_input_y.to(at::kHalf);
+  }
+  grad_input.copy_(grad_input_y);
 }
 
 void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
