@@ -39,13 +39,26 @@ Tensor nn_relu(Tensor self) {
/* nn.softmax */
static Tensor GradFn_softmax(Tensor self, int i) {
    Tensor input = self.node->inputs[i];
-    Tensor res = Tensor_new(input.shape, false);
-    for(int j = 0; j < input.data->numel; j++) {
-        float softmax_j = self.data->flex[j];
-        for(int k = 0; k < input.data->numel; k++) {
-            float softmax_k = self.data->flex[k];
-            float delta_jk = (j == k) ? 1.0f : 0.0f;
-            res.data->flex[j * input.data->numel + k] = softmax_j * (delta_jk - softmax_k);
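+    // Backward as a row-wise Jacobian-vector product:
+    //   dL/dx_j = s_j * (g_j - sum_k s_k * g_k),  g = upstream grad (self.node->grad).
+    // The old code built the full Jacobian into a tensor of input.shape, writing out of bounds.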
+    int num_classes = input.shape[TensorShape_dim(input.shape) - 1];
+    int batch_size = input.data->numel / num_classes;
+
+    Tensor res = Tensor_zeros(input.shape, false);
+
+    for(int batch = 0; batch < batch_size; batch++) {
+        float grad_sum = 0.0f;
+
+        for(int k = 0; k < num_classes; k++) {
+            grad_sum += self.data->flex[batch * num_classes + k] * self.node->grad.data->flex[batch * num_classes + k];
+        }
+
+        for(int j = 0; j < num_classes; j++) {
+            float softmax_j = self.data->flex[batch * num_classes + j];
+            float grad_j = self.node->grad.data->flex[batch * num_classes + j];
+
+            res.data->flex[batch * num_classes + j] = softmax_j * (grad_j - grad_sum);
        }
    }
    return res;
@@ -88,6 +101,29 @@ Tensor nn_softmax(Tensor self) {
    return res;
}

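+// Gradient of cross-entropy w.r.t. inputs[i]; only y_pred (i == 1) gets a
+// nonzero gradient, since y_true holds constant labels.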
+Tensor GradFn_crossentropy(Tensor self, int i) {
+    Tensor y_true = self.node->inputs[0];
+    Tensor y_pred = self.node->inputs[1];
+
+    Tensor grad = Tensor_zeros(y_pred.shape, false);
+
+    if(i == 1) {
+        // f'(y_true, y_pred) = -y_true / y_pred
+        int n_samples = y_pred.shape[0];
+        int n_classes = y_pred.shape[1];
+
+        for(int s = 0; s < n_samples; s++) {
+            for(int c = 0; c < n_classes; c++) {
+                int idx = s * n_classes + c;
+                grad.data->flex[idx] = -y_true.data->flex[idx] / y_pred.data->flex[idx];
+            }
+        }
+    }
+    return grad;
+}
+
/* nn.cross_entropy */
Tensor nn_crossentropy(Tensor y_true, Tensor y_pred) {
    // y_true: [None, n_classes]
@@ -101,7 +137,9 @@ Tensor nn_crossentropy(Tensor y_true, Tensor y_pred) {
    assert(n_classes == y_pred.shape[1]);

    bool requires_grad = !cten_is_eval() && (y_true.node != NULL || y_pred.node != NULL);
-    Tensor res = Tensor_new((TensorShape){n_samples}, requires_grad);
+    Tensor res = Tensor_new((TensorShape){n_samples, 1}, requires_grad);
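+    // Per-sample losses are now a column vector [n_samples, 1] rather than
+    // [n_samples]; presumably the extra dim is expected by the autograd node added below.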
    for(int i = 0; i < n_samples; i++) {
        float loss = 0;
        for(int j = 0; j < n_classes; j++) {
@@ -110,5 +148,13 @@ Tensor nn_crossentropy(Tensor y_true, Tensor y_pred) {
        }
        res.data->flex[i] = -loss;
    }
+
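+    // Wire the loss into the autograd graph so backprop reaches GradFn_crossentropy.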
+    if(requires_grad) {
+        res.node->grad_fn = GradFn_crossentropy;
+        res.node->inputs[0] = y_true;
+        res.node->inputs[1] = y_pred;
+        res.node->n_inputs = 2;
+    }
    return Tensor_mean(res);
}