Skip to content

Commit 2fbdb2f

Browse files
committed
generic: sycl: RNN Vanilla BWD
1 parent cba91c3 commit 2fbdb2f

File tree

8 files changed

+1261
-336
lines changed

8 files changed

+1261
-336
lines changed

src/gpu/generic/sycl/rnn/cell_common.cpp

+105-8
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ namespace sycl {
2727
using namespace dnnl::impl::utils;
2828
using namespace rnn_utils;
2929

30-
status_t _ref_rnn_common_t::cell_execution(const cell_ctx_t &cell_struct) {
30+
template <prop_kind_t aprop>
31+
status_t _ref_rnn_common_t<aprop>::cell_execution(
32+
const cell_ctx_t &cell_struct) {
3133

3234
auto cell_layer = cell_struct.workspace.states_range(cell_struct.lay,
3335
cell_struct.lay, cell_struct.dir, cell_struct.dir, cell_struct.iter,
@@ -48,19 +50,114 @@ status_t _ref_rnn_common_t::cell_execution(const cell_ctx_t &cell_struct) {
4850
auto wei_iter
4951
= cell_struct.user_data.wei_iter(cell_struct.lay, cell_struct.dir);
5052

51-
CHECK(gemm_primitive(cell_struct.engine, cell_struct.ctx, wei_layer,
52-
cell_layer, scratch_gates, gemm_layer_fwd));
53+
if (aprop == prop_kind::forward) {
5354

54-
CHECK(gemm_primitive(cell_struct.engine, cell_struct.ctx, wei_iter,
55-
cell_iter, scratch_gates, gemm_iter_fwd));
55+
CHECK(gemm_primitive(cell_struct.engine, cell_struct.ctx, wei_layer,
56+
cell_layer, scratch_gates, gemm_layer_fwd));
5657

57-
CHECK(rnn_bias(cell_struct.ctx, cell_struct.rnn.mb, cell_struct.rnn.dhc,
58-
cell_struct.iter, cell_struct.lay, cell_struct.dir,
59-
cell_struct.workspace, cell_struct.scratch, cell_struct.user_data));
58+
CHECK(gemm_primitive(cell_struct.engine, cell_struct.ctx, wei_iter,
59+
cell_iter, scratch_gates, gemm_iter_fwd));
60+
61+
CHECK(rnn_bias_fwd(cell_struct.ctx, cell_struct.rnn.mb,
62+
cell_struct.rnn.dhc, cell_struct.iter, cell_struct.lay,
63+
cell_struct.dir, cell_struct.workspace, cell_struct.scratch,
64+
cell_struct.user_data));
65+
} else { // backward
66+
67+
wei_layer = cell_struct.user_data.wei_layer(
68+
cell_struct.rnn.n_layer - cell_struct.lay - 1, cell_struct.dir);
69+
wei_iter = cell_struct.user_data.wei_iter(
70+
cell_struct.rnn.n_layer - cell_struct.lay - 1, cell_struct.dir);
71+
72+
auto ws_gates = cell_struct.workspace.gates(
73+
cell_struct.rnn.n_layer - cell_struct.lay, cell_struct.dir,
74+
cell_struct.rnn.n_iter - cell_struct.iter - 1);
75+
76+
auto dsl_lay_off = cell_struct.iter == 0
77+
? cell_struct.rnn.n_layer * 2 - cell_struct.lay - 1
78+
: cell_struct.rnn.n_layer - cell_struct.lay - 1;
79+
80+
auto dsl_iter_offset = cell_struct.iter == 0 ? 0
81+
: cell_struct.dir == 1
82+
? cell_struct.iter
83+
: cell_struct.rnn.n_iter - cell_struct.iter + 1;
84+
auto diff_cell_layer = cell_struct.scratch.diff_states(
85+
dsl_lay_off, cell_struct.dir, dsl_iter_offset);
86+
87+
auto dir_off = cell_struct.dir == 1
88+
? cell_struct.iter + 1
89+
: cell_struct.rnn.n_iter - cell_struct.iter;
90+
auto diff_cell_iter = cell_struct.scratch.diff_states(
91+
cell_struct.rnn.n_layer - cell_struct.lay, cell_struct.dir,
92+
dir_off);
93+
94+
auto wei_cell_layer = cell_struct.workspace.states_range(
95+
cell_struct.rnn.n_layer - 1 - cell_struct.lay,
96+
cell_struct.rnn.n_layer - 1 - cell_struct.lay, cell_struct.dir,
97+
cell_struct.dir, cell_struct.rnn.n_iter - cell_struct.iter - 1,
98+
cell_struct.rnn.n_iter - cell_struct.iter - 1);
99+
100+
auto wci_offset = cell_struct.rnn.n_iter - cell_struct.iter - 2;
101+
102+
if (cell_struct.rnn.n_dir == 2) {
103+
wci_offset = cell_struct.rnn.n_iter - cell_struct.iter - 1 == 0
104+
? cell_struct.rnn.n_iter - cell_struct.iter
105+
- cell_struct.rnn.n_iter - 3
106+
: cell_struct.rnn.n_iter - cell_struct.iter - 2;
107+
}
108+
109+
auto wei_cell_iter = cell_struct.workspace.states_range(
110+
cell_struct.rnn.n_layer - cell_struct.lay,
111+
cell_struct.rnn.n_layer - cell_struct.lay, cell_struct.dir,
112+
cell_struct.dir, wci_offset, wci_offset);
113+
114+
auto diff_gates = cell_struct.scratch.diff_gates(0);
115+
116+
CHECK(rnn_bias_bwd(cell_struct.ctx, cell_struct.rnn.mb,
117+
cell_struct.rnn.dhc, cell_struct.iter, cell_struct.lay,
118+
cell_struct.dir, cell_struct.rnn.n_layer, diff_cell_layer,
119+
diff_cell_iter, cell_struct.user_data, ws_gates, diff_gates));
120+
121+
auto diff_states_1ay = cell_struct.scratch.diff_states(
122+
cell_struct.rnn.n_layer - cell_struct.lay - 1, cell_struct.dir,
123+
cell_struct.rnn.n_iter - cell_struct.iter);
124+
auto dsi_offset = cell_struct.dir == 1
125+
? cell_struct.iter + 1
126+
: cell_struct.rnn.n_iter - cell_struct.iter;
127+
128+
auto diff_states_layer = cell_struct.scratch.diff_states(
129+
cell_struct.rnn.n_layer - cell_struct.lay - 1, cell_struct.dir,
130+
dsi_offset, 0);
131+
132+
CHECK(gemm_primitive(cell_struct.engine, cell_struct.ctx, wei_iter,
133+
diff_gates, diff_states_layer, gemm_iter_bwd));
134+
135+
auto diff_states_iter = cell_struct.scratch.diff_states(
136+
cell_struct.rnn.n_layer - cell_struct.lay - 1, cell_struct.dir,
137+
dsi_offset, 1);
138+
139+
CHECK(gemm_primitive(cell_struct.engine, cell_struct.ctx, wei_layer,
140+
diff_gates, diff_states_iter, gemm_layer_bwd));
141+
142+
auto diff_wei_layer = cell_struct.user_data.diff_wei_layer(
143+
cell_struct.rnn.n_layer - cell_struct.lay - 1, cell_struct.dir);
144+
CHECK(gemm_primitive(cell_struct.engine, cell_struct.ctx, diff_gates,
145+
wei_cell_layer, diff_wei_layer, gemm_diff_wei_layer));
146+
147+
auto diff_wei_iter = cell_struct.user_data.diff_wei_iter(
148+
cell_struct.rnn.n_layer - cell_struct.lay - 1, cell_struct.dir);
149+
CHECK(gemm_primitive(cell_struct.engine, cell_struct.ctx, diff_gates,
150+
wei_cell_iter, diff_wei_iter, gemm_diff_wei_iter));
151+
}
60152

61153
return status::success;
62154
}
63155

156+
// Explicit instantiations of cell_execution for the forward and backward
// propagation kinds. The member template is defined in this translation unit,
// so these are the instantiations this file emits for it.
template status_t _ref_rnn_common_t<prop_kind::forward>::cell_execution(
157+
const cell_ctx_t &cell_struct);
158+
template status_t _ref_rnn_common_t<prop_kind::backward>::cell_execution(
159+
const cell_ctx_t &cell_struct);
160+
64161
} // namespace sycl
65162
} // namespace generic
66163
} // namespace gpu

0 commit comments

Comments
 (0)