@@ -22,12 +22,13 @@ static Tensor GradFn_relu(Tensor self, int i) {
 }
 
 Tensor nn_relu(Tensor self) {
-    Tensor res = Tensor_new(self.shape, self.node != NULL);
+    bool requires_grad = !cten_is_eval() && self.node != NULL;
+    Tensor res = Tensor_new(self.shape, requires_grad);
     for(int i = 0; i < self.data->numel; i++) {
         res.data->flex[i] = fmaxf(0, self.data->flex[i]);
     }
 
-    if(self.node != NULL) {
+    if(requires_grad) {
         res.node->grad_fn = GradFn_relu;
         res.node->inputs[0] = self;
         res.node->n_inputs = 1;
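The hunk above computes `requires_grad` once and reuses it for both the allocation and the graph wiring, gated on `cten_is_eval()` reporting whether the library is in inference mode. A minimal sketch of how such a flag could be backed is shown below; the counter-based nesting and the `cten_begin_eval()`/`cten_end_eval()` pair are assumptions about the library's internals, not confirmed from this diff.

#include <stdbool.h>

/* Sketch only: a depth counter lets eval regions nest safely.
 * The counter name and begin/end pair are hypothetical. */
static int _eval_depth = 0;

void cten_begin_eval() { _eval_depth++; }
void cten_end_eval() { _eval_depth--; }
bool cten_is_eval() { return _eval_depth > 0; }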
@@ -51,8 +52,8 @@ static Tensor GradFn_softmax(Tensor self, int i) {
 }
 
 Tensor nn_softmax(Tensor self) {
-    Tensor res = Tensor_new(self.shape, self.node != NULL);
-
+    bool requires_grad = !cten_is_eval() && self.node != NULL;
+    Tensor res = Tensor_new(self.shape, requires_grad);
     int self_dim = TensorShape_dim(self.shape);
     assert(self_dim > 0);
     int last_dim_size = self.shape[self_dim - 1];
@@ -79,12 +80,11 @@ Tensor nn_softmax(Tensor self) {
         }
     }
 
-    if(self.node != NULL) {
+    if(requires_grad) {
         res.node->grad_fn = GradFn_softmax;
         res.node->inputs[0] = self;
        res.node->n_inputs = 1;
     }
-
     return res;
 }
 
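With both unary ops gated, a caller can run a forward pass during validation without allocating any autograd nodes. A hedged usage sketch, assuming the `cten_begin_eval()`/`cten_end_eval()` pair from the sketch above:

/* Hypothetical inference helper: inside the eval region, nn_relu()
 * and nn_softmax() skip node allocation even if `x` carries one. */
Tensor predict(Tensor x) {
    cten_begin_eval();
    Tensor h = nn_relu(x);        /* requires_grad is false in eval mode */
    Tensor probs = nn_softmax(h); /* likewise: no node, no grad_fn */
    cten_end_eval();
    return probs;
}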
@@ -100,11 +100,13 @@ Tensor nn_crossentropy(Tensor y_true, Tensor y_pred) {
     assert(n_samples == y_pred.shape[0]);
     assert(n_classes == y_pred.shape[1]);
 
-    Tensor res = Tensor_new((TensorShape){n_samples}, true);
+    bool requires_grad = !cten_is_eval() && (y_true.node != NULL || y_pred.node != NULL);
+    Tensor res = Tensor_new((TensorShape){n_samples}, requires_grad);
     for(int i = 0; i < n_samples; i++) {
         float loss = 0;
         for(int j = 0; j < n_classes; j++) {
-            loss += y_true.data->flex[i * n_classes + j] * logf(y_pred.data->flex[i * n_classes + j]);
+            loss +=
+                y_true.data->flex[i * n_classes + j] * logf(y_pred.data->flex[i * n_classes + j]);
         }
         res.data->flex[i] = -loss;
     }
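Note that `nn_crossentropy` previously passed a hardcoded `true`, so it allocated a graph node even when neither input tracked gradients. The new gate ORs both inputs' nodes, which is the usual rule for a two-input op: the output needs a node if either operand does. A hypothetical helper capturing that rule (not part of cten's API):

/* Sketch: the same eval-aware gate generalized to any binary op. */
static bool requires_grad2(Tensor a, Tensor b) {
    return !cten_is_eval() && (a.node != NULL || b.node != NULL);
}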