Skip to content

Commit cc78222

Browse files
committed
generic: sycl: RNN Vanilla BWD
1 parent cefc0be commit cc78222

File tree

8 files changed

+1669
-572
lines changed

8 files changed

+1669
-572
lines changed

src/gpu/generic/sycl/rnn/cell_common.cpp

+99-5
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace sycl {
2727
using namespace dnnl::impl::utils;
2828
using namespace rnn_utils;
2929

30-
status_t _ref_rnn_common_t::cell_execution(const cell_ctx_t &cell_struct) {
30+
status_t ref_rnn_fwd_t::cell_execution(const cell_ctx_t &cell_struct) {
3131

3232
auto cell_layer = cell_struct.workspace.states_range(cell_struct.lay,
3333
cell_struct.lay, cell_struct.dir, cell_struct.dir, cell_struct.iter,
@@ -48,11 +48,11 @@ status_t _ref_rnn_common_t::cell_execution(const cell_ctx_t &cell_struct) {
4848
auto wei_iter
4949
= cell_struct.user_data.wei_iter(cell_struct.lay, cell_struct.dir);
5050

51-
CHECK(gemm_primitive(cell_struct.engine, cell_struct.ctx, wei_layer,
52-
cell_layer, scratch_gates, gemm_layer_fwd));
51+
CHECK(matmul_primitive(cell_struct.engine, cell_struct.ctx, wei_layer,
52+
cell_layer, scratch_gates, matmul_layer_fwd));
5353

54-
CHECK(gemm_primitive(cell_struct.engine, cell_struct.ctx, wei_iter,
55-
cell_iter, scratch_gates, gemm_iter_fwd));
54+
CHECK(matmul_primitive(cell_struct.engine, cell_struct.ctx, wei_iter,
55+
cell_iter, scratch_gates, matmul_iter_fwd));
5656

5757
CHECK(rnn_bias(cell_struct.ctx, cell_struct.rnn.mb, cell_struct.rnn.dhc,
5858
cell_struct.iter, cell_struct.lay, cell_struct.dir,
@@ -61,6 +61,100 @@ status_t _ref_rnn_common_t::cell_execution(const cell_ctx_t &cell_struct) {
6161
return status::success;
6262
}
6363

64+
status_t ref_rnn_bwd_t::cell_execution(const cell_ctx_t &cell_struct) {
65+
66+
auto wei_layer = cell_struct.user_data.wei_layer(
67+
cell_struct.rnn.n_layer - cell_struct.lay - 1, cell_struct.dir);
68+
auto wei_iter = cell_struct.user_data.wei_iter(
69+
cell_struct.rnn.n_layer - cell_struct.lay - 1, cell_struct.dir);
70+
71+
auto ws_gates = cell_struct.workspace.gates(
72+
cell_struct.rnn.n_layer - cell_struct.lay, cell_struct.dir,
73+
cell_struct.rnn.n_iter - cell_struct.iter - 1);
74+
75+
// take into account reading first layer from end of state
76+
// subsequent layers at written at forward pass'
77+
// last layer step location in state
78+
auto dsl_lay_off = cell_struct.iter == 0
79+
? cell_struct.rnn.n_layer * 2 - cell_struct.lay - 1
80+
: cell_struct.rnn.n_layer - cell_struct.lay - 1;
81+
// take into account for bidirectional case
82+
// any timesteps after first one
83+
// write out from bottom up of timestep state location
84+
auto dsl_iter_offset = cell_struct.iter == 0 ? 0
85+
: cell_struct.dir == 1
86+
? cell_struct.iter
87+
: cell_struct.rnn.n_iter - cell_struct.iter + 1;
88+
auto diff_cell_layer = cell_struct.scratch.diff_states(
89+
dsl_lay_off, cell_struct.dir, dsl_iter_offset);
90+
91+
// account for bidirectional cases needing to jumping over
92+
// n_iter state blocks
93+
auto dir_off = cell_struct.dir == 1
94+
? cell_struct.iter + 1
95+
: cell_struct.rnn.n_iter - cell_struct.iter;
96+
auto diff_cell_iter = cell_struct.scratch.diff_states(
97+
cell_struct.rnn.n_layer - cell_struct.lay, cell_struct.dir,
98+
dir_off);
99+
100+
auto wei_cell_layer = cell_struct.workspace.states_range(
101+
cell_struct.rnn.n_layer - 1 - cell_struct.lay,
102+
cell_struct.rnn.n_layer - 1 - cell_struct.lay, cell_struct.dir,
103+
cell_struct.dir, cell_struct.rnn.n_iter - cell_struct.iter - 1,
104+
cell_struct.rnn.n_iter - cell_struct.iter - 1);
105+
106+
auto wci_offset = cell_struct.rnn.n_iter - cell_struct.iter - 2;
107+
108+
if (cell_struct.rnn.n_dir == 2) {
109+
wci_offset = cell_struct.rnn.n_iter - cell_struct.iter - 1 == 0
110+
? cell_struct.rnn.n_iter - cell_struct.iter
111+
- cell_struct.rnn.n_iter - 3
112+
: cell_struct.rnn.n_iter - cell_struct.iter - 2;
113+
}
114+
115+
auto wei_cell_iter = cell_struct.workspace.states_range(
116+
cell_struct.rnn.n_layer - cell_struct.lay,
117+
cell_struct.rnn.n_layer - cell_struct.lay, cell_struct.dir,
118+
cell_struct.dir, wci_offset, wci_offset);
119+
120+
auto diff_gates = cell_struct.scratch.diff_gates(0);
121+
122+
CHECK(rnn_bias(cell_struct.ctx, cell_struct.rnn.mb, cell_struct.rnn.dhc,
123+
cell_struct.iter, cell_struct.lay, cell_struct.dir,
124+
cell_struct.rnn.n_layer, diff_cell_layer, diff_cell_iter,
125+
cell_struct.user_data, ws_gates, diff_gates));
126+
127+
auto dsi_offset = cell_struct.dir == 1
128+
? cell_struct.iter + 1
129+
: cell_struct.rnn.n_iter - cell_struct.iter;
130+
131+
auto diff_states_layer = cell_struct.scratch.diff_states(
132+
cell_struct.rnn.n_layer - cell_struct.lay - 1, cell_struct.dir,
133+
dsi_offset, 0);
134+
135+
CHECK(matmul_primitive(cell_struct.engine, cell_struct.ctx, wei_iter,
136+
diff_gates, diff_states_layer, matmul_iter_bwd));
137+
138+
auto diff_states_iter = cell_struct.scratch.diff_states(
139+
cell_struct.rnn.n_layer - cell_struct.lay - 1, cell_struct.dir,
140+
dsi_offset, 1);
141+
142+
CHECK(matmul_primitive(cell_struct.engine, cell_struct.ctx, wei_layer,
143+
diff_gates, diff_states_iter, matmul_layer_bwd));
144+
145+
auto diff_wei_layer = cell_struct.user_data.diff_wei_layer(
146+
cell_struct.rnn.n_layer - cell_struct.lay - 1, cell_struct.dir);
147+
CHECK(matmul_primitive(cell_struct.engine, cell_struct.ctx, diff_gates,
148+
wei_cell_layer, diff_wei_layer, matmul_diff_wei_layer));
149+
150+
auto diff_wei_iter = cell_struct.user_data.diff_wei_iter(
151+
cell_struct.rnn.n_layer - cell_struct.lay - 1, cell_struct.dir);
152+
CHECK(matmul_primitive(cell_struct.engine, cell_struct.ctx, diff_gates,
153+
wei_cell_iter, diff_wei_iter, matmul_diff_wei_iter));
154+
155+
return status::success;
156+
}
157+
64158
} // namespace sycl
65159
} // namespace generic
66160
} // namespace gpu

0 commit comments

Comments
 (0)