File tree 2 files changed +4
-1
lines changed
2 files changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -11,7 +11,7 @@ class ReLu : public Module
11
11
std::vector<float > backward (const std::vector<float > &grad_output) override ;
12
12
void update (float lr) override ;
13
13
private:
14
- std::vector<bool > zeroed;
14
+ std::vector<bool > zeroed; // Store for backpropagation
15
15
};
16
16
17
17
#endif
Original file line number Diff line number Diff line change @@ -19,6 +19,8 @@ class LinearLayer : public Module
19
19
vector<float > forward (const vector<float > &input) override ;
20
20
vector<float > backward (const vector<float > &grad_output) override ;
21
21
void update (float lr) override ;
22
+
23
+ // Expose weights and biases for comparison
22
24
vector<vector<float >> weights;
23
25
vector<float > bias;
24
26
vector<vector<float >> grad_weights;
@@ -28,6 +30,7 @@ class LinearLayer : public Module
28
30
int input_size;
29
31
int output_size;
30
32
33
+ // Store input for backward pass
31
34
vector<float > input;
32
35
};
33
36
You can’t perform that action at this time.
0 commit comments