Skip to content

Commit cffea66

Browse files
authored
Refactor Leaky ReLU to Use Activations in Backward (#2262)
* Change leaky relu backprop to use activations * Update backprop requirements
1 parent 369eae5 commit cffea66

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

include/lbann/layers/activations/leaky_relu.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ class leaky_relu_layer : public data_type_layer<TensorDataType>
128128
bool can_run_inplace() const override { return true; }
129129
int get_backprop_requirements() const override
130130
{
131-
return ERROR_SIGNALS | PREV_ACTIVATIONS;
131+
return ERROR_SIGNALS | ACTIVATIONS;
132132
}
133133

134134
private:

src/layers/activations/leaky_relu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ template <typename TensorDataType, data_layout Layout, El::Device Device>
8484
void leaky_relu_layer<TensorDataType, Layout, Device>::bp_compute()
8585
{
8686
local_bp<TensorDataType>(this->m_negative_slope,
87-
this->get_local_prev_activations(),
87+
this->get_local_activations(),
8888
this->get_local_prev_error_signals(),
8989
this->get_local_error_signals());
9090
}

src/layers/activations/leaky_relu.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ void bp_compute_distconv(
179179
TensorDataType negative_slope)
180180
{
181181
assert_always(Layout == data_layout::DATA_PARALLEL);
182-
dc.m_leaky_relu->backward(dc.get_prev_activations(),
182+
dc.m_leaky_relu->backward(dc.get_activations(),
183183
dc.get_prev_error_signals(),
184184
negative_slope,
185185
dc.get_error_signals());
@@ -211,7 +211,7 @@ void leaky_relu_layer<TensorDataType, Layout, Device>::bp_compute()
211211
}
212212
#endif // LBANN_HAS_DISTCONV
213213
local_bp(this->m_negative_slope,
214-
this->get_local_prev_activations(),
214+
this->get_local_activations(),
215215
this->get_local_prev_error_signals(),
216216
this->get_local_error_signals());
217217
}

0 commit comments

Comments
 (0)