@@ -4,6 +4,13 @@ using namespace std;
 void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor output_y = output;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    output_y = output.to(at::kFloat);
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::ones_like(input);
   if (n_class == 1) {
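The guard introduced here promotes half-precision tensors to float before the NPU op runs, evidently because the underlying SigmoidFocalLoss kernel computes in fp32. A minimal sketch of the promotion step in isolation (the helper name is hypothetical, not part of this PR; ATen only):

#include <ATen/ATen.h>

// Hypothetical helper isolating the promotion step: report whether the tensor
// was fp16 and hand back an fp32 copy if so, the original tensor otherwise.
static at::Tensor promote_if_half(const at::Tensor& t, bool& is_half) {
  is_half = t.scalar_type() == at::kHalf;
  return is_half ? t.to(at::kFloat) : t;
}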
@@ -15,21 +22,28 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
   }
   target_y = target_y.to(at::kInt);
   int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
   if (weight_size > 0) {
     weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
   }
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SigmoidFocalLoss")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(weight_y)
-      .Output(output)
+      .Output(output_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    output_y = output_y.to(at::kHalf);
+  }
+  output.copy_(output_y);
 }

 void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
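The explicit copy-back at the end is what makes the round-trip correct: Tensor::to with a different dtype allocates a fresh buffer rather than a view, so the kernel's writes land in output_y, not in the caller's output. A small illustration of that semantics (values are made up):

at::Tensor out = at::zeros({2, 3}, at::kHalf);  // caller-visible fp16 buffer
at::Tensor out_f = out.to(at::kFloat);          // new fp32 tensor, not a view
out_f.fill_(1.0f);                              // `out` still holds zeros here
out.copy_(out_f);                               // writes (and casts) back into `out`

Note that the patch casts the fp32 result back to fp16 before the copy_, although copy_ itself would perform the same cast.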
@@ -38,6 +52,13 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
 void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                      Tensor grad_input, float gamma,
                                      float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor grad_input_y = grad_input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    grad_input_y = grad_input.to(at::kFloat);
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::ones_like(input);
   if (n_class == 1) {
@@ -50,22 +71,29 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
   target_y = target_y.to(at::kInt);
   at::Tensor grad_up = at::ones_like(input);
   int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
   if (weight_size > 0) {
     weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
   }
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SigmoidFocalLossGrad")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(grad_up)
       .Input(weight_y)
-      .Output(grad_input)
+      .Output(grad_input_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    grad_input_y = grad_input_y.to(at::kHalf);
+  }
+  grad_input.copy_(grad_input_y);
 }

 void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
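Forward and backward apply the identical guard, so the whole patch boils down to one reusable shape: promote fp16 operands, run the fp32 op, copy the result back into the original out tensor. A hedged generalization of that shape (the template and its names are not part of this PR):

#include <ATen/ATen.h>

// Hypothetical wrapper for any out-parameter kernel that only supports fp32.
template <typename Kernel>
void run_with_fp32_fallback(const at::Tensor& input, at::Tensor& out,
                            Kernel&& kernel) {
  const bool is_half = input.scalar_type() == at::kHalf;
  at::Tensor in_y = is_half ? input.to(at::kFloat) : input;
  at::Tensor out_y = is_half ? out.to(at::kFloat) : out;
  kernel(in_y, out_y);  // launch the fp32-only op
  if (is_half) {
    out.copy_(out_y);   // cast the fp32 result back into the fp16 out tensor
  }
}

Doing the copy only under is_half, as sketched here, would also avoid the redundant self-copy the patch performs in the plain fp32 case, where output_y and output alias the same tensor.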
@@ -74,26 +102,38 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,

 void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
+  at::Tensor input_y = input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::one_hot(target, n_class);
   target_y = target_y.to(at::kInt);
   int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
   if (weight_size > 0) {
     weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
   }
-  at::Tensor op_output = at::ones_like(input);
+
+  at::Tensor op_output = at::ones_like(input_y);
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SoftmaxFocalLoss")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(weight_y)
       .Output(op_output)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    op_output = op_output.to(at::kHalf);
+  }
   int64_t n_batch = input.size(0);
   c10::SmallVector<int64_t, 2> offsets = {0, 0};
   c10::SmallVector<int64_t, 2> sizes = {n_batch, 1};
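One dtype detail in the softmax variants: at::one_hot only accepts an int64 index tensor and returns int64, which is why the result still has to be narrowed to int32 for the NPU op. For example:

at::Tensor target = at::arange(2, at::kLong);            // class indices {0, 1}, int64
at::Tensor oh = at::one_hot(target, /*num_classes=*/3);  // shape [2, 3], still int64
oh = oh.to(at::kInt);                                    // int32, as the op expects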
@@ -124,27 +164,41 @@ void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
 void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                      Tensor buff, Tensor grad_input,
                                      float gamma, float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor grad_input_y = grad_input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    grad_input_y = grad_input.to(at::kFloat);
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::one_hot(target, n_class);
   target_y = target_y.to(at::kInt);
   at::Tensor grad_up = at::ones_like(input);
   int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
   if (weight_size > 0) {
     weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
   }
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SoftmaxFocalLossGrad")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(grad_up)
       .Input(weight_y)
-      .Output(grad_input)
+      .Output(grad_input_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    grad_input_y = grad_input_y.to(at::kHalf);
+  }
+  grad_input.copy_(grad_input_y);
 }

 void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
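A quick way to sanity-check the new path is to run the same inputs through both dtypes and compare with an fp16-sized tolerance. A hedged smoke-test sketch (device placement is elided and the tolerances are assumptions; on a real run every tensor would live on the NPU):

at::Tensor x = at::randn({4, 8});                       // fp32 reference input
at::Tensor t = at::randint(/*high=*/8, {4}, at::kInt);  // class indices
at::Tensor w = at::ones({8});                           // per-class weights
at::Tensor out_f = at::zeros({4, 8});
at::Tensor out_h = at::zeros({4, 8}, at::kHalf);
sigmoid_focal_loss_forward_npu(x, t, w, out_f, 2.0f, 0.25f);
sigmoid_focal_loss_forward_npu(x.to(at::kHalf), t, w, out_h, 2.0f, 0.25f);
// fp16 keeps roughly three significant decimal digits, so compare loosely.
bool ok = at::allclose(out_f, out_h.to(at::kFloat), /*rtol=*/1e-3, /*atol=*/1e-3);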