Skip to content

Commit ea09e16

Browse files
authored
Merge pull request #207 from NVlabs/overhaul-parameter-init
Overhaul parameter init
2 parents ed53430 + 3a3667a commit ea09e16

17 files changed: +191 −217 lines changed

DOCUMENTATION.md

-2
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ Lightning fast implementation of small multi-layer perceptrons (MLPs). Restricte
3737
"n_neurons": 128, // Neurons in each hidden layer.
3838
// May only be 16, 32, 64, or 128.
3939
"n_hidden_layers": 5, // Number of hidden layers.
40-
"feedback_alignment": false // Use feedback alignment
41-
// [Lillicrap et al. 2016].
4240
}
4341
```
4442

include/tiny-cuda-nn/cpp_api.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class Module {
8686
return m_param_precision;
8787
}
8888

89-
virtual void initialize_params(size_t seed, float* params_full_precision) = 0;
89+
virtual void initialize_params(size_t seed, float* params_full_precision, float scale = 1.0f) = 0;
9090

9191
virtual uint32_t n_output_dims() const = 0;
9292
EPrecision output_precision() const {

include/tiny-cuda-nn/encoding.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ class Encoding : public DifferentiableObject<float, T, T> {
7676
virtual MatrixLayout preferred_output_layout() const = 0;
7777

7878
// By default, an encoding has no parameters
79-
void set_params(T* params, T* inference_params, T* backward_params, T* gradients) override { }
80-
void initialize_params(pcg32& rnd, float* params_full_precision, T* params, T* inference_params, T* backward_params, T* gradients, float scale = 1) override { }
79+
void set_params_impl(T* params, T* inference_params, T* gradients) override { }
80+
void initialize_params(pcg32& rnd, float* params_full_precision, float scale = 1) override { }
8181
size_t n_params() const override { return 0; }
8282

8383
std::vector<std::pair<uint32_t, uint32_t>> layer_sizes() const override { return {}; }

include/tiny-cuda-nn/encodings/composite.h

+5-19
Original file line numberDiff line numberDiff line change
@@ -403,32 +403,18 @@ class CompositeEncoding : public Encoding<T> {
403403
return m_nested.empty() ? AoS : m_nested.front()->preferred_output_layout();
404404
}
405405

406-
void set_params(T* params, T* inference_params, T* backward_params, T* gradients) override {
406+
void set_params_impl(T* params, T* inference_params, T* gradients) override {
407407
size_t offset = 0;
408408
for (auto& nested : m_nested) {
409-
nested->set_params(
410-
params + offset,
411-
inference_params + offset,
412-
backward_params + offset,
413-
gradients + offset
414-
);
409+
nested->set_params(params + offset, inference_params + offset, gradients + offset);
415410
offset += nested->n_params();
416411
}
417412
}
418413

419-
void initialize_params(pcg32& rnd, float* params_full_precision, T* params, T* inference_params, T* backward_params, T* gradients, float scale = 1) override {
420-
size_t offset = 0;
414+
void initialize_params(pcg32& rnd, float* params_full_precision, float scale = 1) override {
421415
for (auto& nested : m_nested) {
422-
nested->initialize_params(
423-
rnd,
424-
params_full_precision + offset,
425-
params + offset,
426-
inference_params + offset,
427-
backward_params + offset,
428-
gradients + offset,
429-
scale
430-
);
431-
offset += nested->n_params();
416+
nested->initialize_params(rnd, params_full_precision, scale);
417+
params_full_precision += nested->n_params();
432418
}
433419
}
434420

include/tiny-cuda-nn/encodings/grid.h

+9-20
Original file line numberDiff line numberDiff line change
@@ -1079,7 +1079,7 @@ class GridEncodingTemplated : public GridEncoding<T> {
10791079
this->m_max_level_gpu,
10801080
m_interpolation_type,
10811081
m_grid_type,
1082-
use_inference_params ? m_grid_inference : m_grid,
1082+
use_inference_params ? this->inference_params() : this->params(),
10831083
forward->positions.data() ? forward->positions.view() : input.view(),
10841084
encoded_positions_soa,
10851085
forward->dy_dx.data()
@@ -1144,7 +1144,7 @@ class GridEncodingTemplated : public GridEncoding<T> {
11441144
grid_gradient_tmp = allocate_workspace(stream, m_n_params * sizeof(grad_t));
11451145
grid_gradient = (grad_t*)grid_gradient_tmp.data();
11461146
} else {
1147-
grid_gradient = (grad_t*)m_grid_gradient;
1147+
grid_gradient = (grad_t*)this->gradients();
11481148
}
11491149

11501150
if (param_gradients_mode == EGradientMode::Overwrite) {
@@ -1173,7 +1173,7 @@ class GridEncodingTemplated : public GridEncoding<T> {
11731173
);
11741174

11751175
if (!std::is_same<grad_t, T>::value) {
1176-
parallel_for_gpu(stream, n_params(), [grad=m_grid_gradient, grad_tmp=grid_gradient] __device__ (size_t i) {
1176+
parallel_for_gpu(stream, n_params(), [grad=this->gradients(), grad_tmp=grid_gradient] __device__ (size_t i) {
11771177
grad[i] = (T)grad_tmp[i];
11781178
});
11791179
}
@@ -1238,7 +1238,7 @@ class GridEncodingTemplated : public GridEncoding<T> {
12381238
grid_gradient_tmp = allocate_workspace(stream, m_n_params * sizeof(grad_t));
12391239
grid_gradient = (grad_t*)grid_gradient_tmp.data();
12401240
} else {
1241-
grid_gradient = (grad_t*)m_grid_gradient;
1241+
grid_gradient = (grad_t*)this->gradients();
12421242
}
12431243

12441244
if (param_gradients_mode == EGradientMode::Overwrite) {
@@ -1270,7 +1270,7 @@ class GridEncodingTemplated : public GridEncoding<T> {
12701270
);
12711271

12721272
if (!std::is_same<grad_t, T>::value) {
1273-
parallel_for_gpu(stream, n_params(), [grad=m_grid_gradient, grad_tmp=grid_gradient] __device__ (size_t i) {
1273+
parallel_for_gpu(stream, n_params(), [grad=this->gradients(), grad_tmp=grid_gradient] __device__ (size_t i) {
12741274
grad[i] = (T)grad_tmp[i];
12751275
});
12761276
}
@@ -1312,7 +1312,7 @@ class GridEncodingTemplated : public GridEncoding<T> {
13121312
dL_ddLdinput.view(),
13131313
forward.positions.data() ? forward.positions.view() : input.view(),
13141314
dL_dy_rm,
1315-
use_inference_params ? m_grid_inference : m_grid,
1315+
use_inference_params ? this->inference_params() : this->params(),
13161316
// outputs
13171317
dL_dinput->view()
13181318
);
@@ -1348,17 +1348,11 @@ class GridEncodingTemplated : public GridEncoding<T> {
13481348
return SoA;
13491349
}
13501350

1351-
void set_params(T* params, T* inference_params, T* backward_params, T* gradients) override {
1352-
m_grid = params;
1353-
m_grid_inference = inference_params;
1354-
m_grid_gradient = gradients;
1355-
}
1356-
1357-
void initialize_params(pcg32& rnd, float* params_full_precision, T* params, T* inference_params, T* backward_params, T* gradients, float scale = 1) override {
1358-
set_params(params, inference_params, backward_params, gradients);
1351+
void set_params_impl(T* params, T* inference_params, T* gradients) override { }
13591352

1353+
void initialize_params(pcg32& rnd, float* params_full_precision, float scale = 1) override {
13601354
// Initialize the hashgrid from the GPU, because the number of parameters can be quite large.
1361-
generate_random_uniform<float>(rnd, n_params(), params_full_precision, -1e-4f, 1e-4f);
1355+
generate_random_uniform<float>(rnd, n_params(), params_full_precision, -1e-4f * scale, 1e-4f * scale);
13621356
}
13631357

13641358
size_t n_params() const override {
@@ -1434,11 +1428,6 @@ class GridEncodingTemplated : public GridEncoding<T> {
14341428
bool m_stochastic_interpolation;
14351429
InterpolationType m_interpolation_type;
14361430
GridType m_grid_type;
1437-
1438-
// Storage of params
1439-
T* m_grid;
1440-
T* m_grid_inference;
1441-
T* m_grid_gradient;
14421431
};
14431432

14441433
template <typename T, uint32_t N_FEATURES_PER_LEVEL, HashType HASH_TYPE>

include/tiny-cuda-nn/network.h

-6
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,6 @@
3434

3535
TCNN_NAMESPACE_BEGIN
3636

37-
enum class WeightUsage {
38-
Inference,
39-
Forward,
40-
Backward,
41-
};
42-
4337
Activation string_to_activation(const std::string& activation_name);
4438
std::string to_string(Activation activation);
4539

include/tiny-cuda-nn/network_with_input_encoding.h

+8-35
Original file line numberDiff line numberDiff line change
@@ -110,48 +110,21 @@ class NetworkWithInputEncoding : public Network<float, T> {
110110
}
111111
}
112112

113-
void set_params(T* params, T* inference_params, T* backward_params, T* gradients) override {
113+
void set_params_impl(T* params, T* inference_params, T* gradients) override {
114114
size_t offset = 0;
115-
m_network->set_params(
116-
params + offset,
117-
inference_params + offset,
118-
backward_params + offset,
119-
gradients + offset
120-
);
115+
m_network->set_params(params + offset, inference_params + offset, gradients + offset);
121116
offset += m_network->n_params();
122117

123-
m_encoding->set_params(
124-
params + offset,
125-
inference_params + offset,
126-
backward_params + offset,
127-
gradients + offset
128-
);
118+
m_encoding->set_params(params + offset, inference_params + offset, gradients + offset);
129119
offset += m_encoding->n_params();
130120
}
131121

132-
void initialize_params(pcg32& rnd, float* params_full_precision, T* params, T* inference_params, T* backward_params, T* gradients, float scale = 1) override {
133-
size_t offset = 0;
134-
m_network->initialize_params(
135-
rnd,
136-
params_full_precision + offset,
137-
params + offset,
138-
inference_params + offset,
139-
backward_params + offset,
140-
gradients + offset,
141-
scale
142-
);
143-
offset += m_network->n_params();
122+
void initialize_params(pcg32& rnd, float* params_full_precision, float scale = 1) override {
123+
m_network->initialize_params(rnd, params_full_precision, scale);
124+
params_full_precision += m_network->n_params();
144125

145-
m_encoding->initialize_params(
146-
rnd,
147-
params_full_precision + offset,
148-
params + offset,
149-
inference_params + offset,
150-
backward_params + offset,
151-
gradients + offset,
152-
scale
153-
);
154-
offset += m_encoding->n_params();
126+
m_encoding->initialize_params(rnd, params_full_precision, scale);
127+
params_full_precision += m_encoding->n_params();
155128
}
156129

157130
size_t n_params() const override {

include/tiny-cuda-nn/networks/cutlass_mlp.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ class CutlassMLP : public Network<T> {
6666
EGradientMode param_gradients_mode = EGradientMode::Overwrite
6767
) override;
6868

69-
void set_params(T* params, T* inference_params, T* backward_params, T* gradients) override;
70-
void initialize_params(pcg32& rnd, float* params_full_precision, T* params, T* inference_params, T* backward_params, T* gradients, float scale = 1) override;
69+
void set_params_impl(T* params, T* inference_params, T* gradients) override;
70+
void initialize_params(pcg32& rnd, float* params_full_precision, float scale = 1) override;
7171

7272
GPUMatrix<T, RM>& input_weight_matrix(bool inference) {
7373
auto& weight_matrices = inference ? m_weight_matrices_inference : m_weight_matrices;

include/tiny-cuda-nn/networks/fully_fused_mlp.h

+12-33
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ TCNN_NAMESPACE_BEGIN
4242
template <typename T, int WIDTH>
4343
class FullyFusedMLP : public Network<T> {
4444
public:
45-
FullyFusedMLP(uint32_t input_width, uint32_t output_width, uint32_t n_hidden_layers, bool use_feedback_alignment, Activation activation, Activation output_activation);
45+
FullyFusedMLP(uint32_t input_width, uint32_t output_width, uint32_t n_hidden_layers, Activation activation, Activation output_activation);
4646

4747
void inference_mixed_precision_impl(cudaStream_t stream, const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<T>& output, bool use_inference_params = true) override;
4848

@@ -59,37 +59,22 @@ class FullyFusedMLP : public Network<T> {
5959
EGradientMode param_gradients_mode = EGradientMode::Overwrite
6060
) override;
6161

62-
void set_params(T* params, T* inference_params, T* backward_params, T* gradients) override;
63-
void initialize_params(pcg32& rnd, float* params_full_precision, T* params, T* inference_params, T* backward_params, T* gradients, float scale = 1) override;
62+
void set_params_impl(T* params, T* inference_params, T* gradients) override;
63+
void initialize_params(pcg32& rnd, float* params_full_precision, float scale = 1) override;
6464

65-
GPUMatrix<T, RM>& input_weight_matrix(WeightUsage usage) {
66-
switch (usage) {
67-
case WeightUsage::Inference: return m_weight_matrices_inference.front();
68-
case WeightUsage::Forward: return m_weight_matrices.front();
69-
case WeightUsage::Backward: return m_weight_matrices_backward.front();
70-
}
71-
72-
throw std::runtime_error{"Invalid weight usage."};
65+
GPUMatrix<T, RM>& input_weight_matrix(bool inference) {
66+
auto& weight_matrices = inference ? m_weight_matrices_inference : m_weight_matrices;
67+
return weight_matrices.front();
7368
}
7469

75-
GPUMatrix<T, RM>& weight_matrix_at(WeightUsage usage, uint32_t idx) {
76-
switch (usage) {
77-
case WeightUsage::Inference: return m_weight_matrices_inference.at(1 + idx);
78-
case WeightUsage::Forward: return m_weight_matrices.at(1 + idx);
79-
case WeightUsage::Backward: return m_weight_matrices_backward.at(1 + idx);
80-
}
81-
82-
throw std::runtime_error{"Invalid weight usage."};
70+
GPUMatrix<T, RM>& weight_matrix_at(bool inference, uint32_t idx) {
71+
auto& weight_matrices = inference ? m_weight_matrices_inference : m_weight_matrices;
72+
return weight_matrices.at(1 + idx);
8373
}
8474

85-
GPUMatrix<T, RM>& output_weight_matrix(WeightUsage usage) {
86-
switch (usage) {
87-
case WeightUsage::Inference: return m_weight_matrices_inference.back();
88-
case WeightUsage::Forward: return m_weight_matrices.back();
89-
case WeightUsage::Backward: return m_weight_matrices_backward.back();
90-
}
91-
92-
throw std::runtime_error{"Invalid weight usage."};
75+
GPUMatrix<T, RM>& output_weight_matrix(bool inference) {
76+
auto& weight_matrices = inference ? m_weight_matrices_inference : m_weight_matrices;
77+
return weight_matrices.back();
9378
}
9479

9580
GPUMatrix<T, RM>& input_gradient_matrix() {
@@ -156,7 +141,6 @@ class FullyFusedMLP : public Network<T> {
156141
{"output_activation", to_string(m_output_activation)},
157142
{"n_neurons", m_network_width},
158143
{"n_hidden_layers", m_n_hidden_layers},
159-
{"feedback_alignment", m_use_feedback_alignment},
160144
};
161145
}
162146

@@ -178,16 +162,11 @@ class FullyFusedMLP : public Network<T> {
178162
Activation m_activation;
179163
Activation m_output_activation;
180164

181-
bool m_use_feedback_alignment = false;
182-
183165
// Storage of params
184166
std::vector<GPUMatrix<T, RM>> m_weight_matrices;
185167
std::vector<GPUMatrix<T, RM>> m_weight_matrices_inference;
186-
std::vector<GPUMatrix<T, RM>> m_weight_matrices_backward;
187168
size_t m_total_n_params;
188169

189-
std::vector<GPUMatrix<float, RM>> m_weight_matrices_full_precision;
190-
191170
std::vector<GPUMatrix<T, RM>> m_gradient_matrices;
192171
};
193172

0 commit comments

Comments (0)