@@ -42,7 +42,7 @@ TCNN_NAMESPACE_BEGIN
 template <typename T, int WIDTH>
 class FullyFusedMLP : public Network<T> {
 public:
-	FullyFusedMLP(uint32_t input_width, uint32_t output_width, uint32_t n_hidden_layers, bool use_feedback_alignment, Activation activation, Activation output_activation);
+	FullyFusedMLP(uint32_t input_width, uint32_t output_width, uint32_t n_hidden_layers, Activation activation, Activation output_activation);
 
 	void inference_mixed_precision_impl(cudaStream_t stream, const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<T>& output, bool use_inference_params = true) override;
 
@@ -59,37 +59,22 @@ class FullyFusedMLP : public Network<T> {
 		EGradientMode param_gradients_mode = EGradientMode::Overwrite
 	) override;
 
-	void set_params(T* params, T* inference_params, T* backward_params, T* gradients) override;
-	void initialize_params(pcg32& rnd, float* params_full_precision, T* params, T* inference_params, T* backward_params, T* gradients, float scale = 1) override;
+	void set_params_impl(T* params, T* inference_params, T* gradients) override;
+	void initialize_params(pcg32& rnd, float* params_full_precision, float scale = 1) override;
 
-	GPUMatrix<T, RM>& input_weight_matrix(WeightUsage usage) {
-		switch (usage) {
-			case WeightUsage::Inference: return m_weight_matrices_inference.front();
-			case WeightUsage::Forward: return m_weight_matrices.front();
-			case WeightUsage::Backward: return m_weight_matrices_backward.front();
-		}
-
-		throw std::runtime_error{"Invalid weight usage."};
+	GPUMatrix<T, RM>& input_weight_matrix(bool inference) {
+		auto& weight_matrices = inference ? m_weight_matrices_inference : m_weight_matrices;
+		return weight_matrices.front();
 	}
 
-	GPUMatrix<T, RM>& weight_matrix_at(WeightUsage usage, uint32_t idx) {
-		switch (usage) {
-			case WeightUsage::Inference: return m_weight_matrices_inference.at(1 + idx);
-			case WeightUsage::Forward: return m_weight_matrices.at(1 + idx);
-			case WeightUsage::Backward: return m_weight_matrices_backward.at(1 + idx);
-		}
-
-		throw std::runtime_error{"Invalid weight usage."};
+	GPUMatrix<T, RM>& weight_matrix_at(bool inference, uint32_t idx) {
+		auto& weight_matrices = inference ? m_weight_matrices_inference : m_weight_matrices;
+		return weight_matrices.at(1 + idx);
 	}
 
-	GPUMatrix<T, RM>& output_weight_matrix(WeightUsage usage) {
-		switch (usage) {
-			case WeightUsage::Inference: return m_weight_matrices_inference.back();
-			case WeightUsage::Forward: return m_weight_matrices.back();
-			case WeightUsage::Backward: return m_weight_matrices_backward.back();
-		}
-
-		throw std::runtime_error{"Invalid weight usage."};
+	GPUMatrix<T, RM>& output_weight_matrix(bool inference) {
+		auto& weight_matrices = inference ? m_weight_matrices_inference : m_weight_matrices;
+		return weight_matrices.back();
 	}
 
 	GPUMatrix<T, RM>& input_gradient_matrix() {
@@ -156,7 +141,6 @@ class FullyFusedMLP : public Network<T> {
 			{"output_activation", to_string(m_output_activation)},
 			{"n_neurons", m_network_width},
 			{"n_hidden_layers", m_n_hidden_layers},
-			{"feedback_alignment", m_use_feedback_alignment},
 		};
 	}
 
@@ -178,16 +162,11 @@ class FullyFusedMLP : public Network<T> {
 	Activation m_activation;
 	Activation m_output_activation;
 
-	bool m_use_feedback_alignment = false;
-
 	// Storage of params
 	std::vector<GPUMatrix<T, RM>> m_weight_matrices;
 	std::vector<GPUMatrix<T, RM>> m_weight_matrices_inference;
-	std::vector<GPUMatrix<T, RM>> m_weight_matrices_backward;
 	size_t m_total_n_params;
 
-	std::vector<GPUMatrix<float, RM>> m_weight_matrices_full_precision;
-
 	std::vector<GPUMatrix<T, RM>> m_gradient_matrices;
 };
 
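For reference, below is a minimal standalone sketch of the accessor pattern this change introduces: the three-way `WeightUsage` switch collapses into a `bool inference` flag that selects between the training and inference weight vectors, and the `WeightUsage::Backward` path goes away together with `m_weight_matrices_backward`. The `Matrix` alias and the `main` driver are illustrative placeholders only, not part of the real header, which operates on `GPUMatrix<T, RM>` in GPU memory.

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Placeholder for GPUMatrix<T, RM>; a plain string keeps the sketch runnable on the host.
using Matrix = std::string;

struct MlpWeightsSketch {
	// Layout mirrors the header: front() is the input layer, back() the output
	// layer, and hidden layer idx lives at index 1 + idx.
	std::vector<Matrix> m_weight_matrices           = {"train-in", "train-h0", "train-out"};
	std::vector<Matrix> m_weight_matrices_inference = {"infer-in", "infer-h0", "infer-out"};

	Matrix& input_weight_matrix(bool inference) {
		auto& weight_matrices = inference ? m_weight_matrices_inference : m_weight_matrices;
		return weight_matrices.front();
	}

	Matrix& weight_matrix_at(bool inference, uint32_t idx) {
		auto& weight_matrices = inference ? m_weight_matrices_inference : m_weight_matrices;
		return weight_matrices.at(1 + idx); // skip the input layer
	}

	Matrix& output_weight_matrix(bool inference) {
		auto& weight_matrices = inference ? m_weight_matrices_inference : m_weight_matrices;
		return weight_matrices.back();
	}
};

int main() {
	MlpWeightsSketch w;
	// Call sites that used WeightUsage::Forward now pass false, and
	// WeightUsage::Inference callers pass true; WeightUsage::Backward
	// callers have no replacement and must be removed.
	std::cout << w.input_weight_matrix(false) << "\n"; // train-in
	std::cout << w.weight_matrix_at(true, 0) << "\n";  // infer-h0
	std::cout << w.output_weight_matrix(true) << "\n"; // infer-out
}
```

The same simplification drives the signature changes above: with the backward and full-precision weight copies gone, `set_params_impl` drops the `backward_params` pointer and `initialize_params` shrinks to the RNG, the full-precision destination, and a scale.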