@@ -27,7 +27,7 @@ namespace sycl {
27
27
using namespace dnnl ::impl::utils;
28
28
using namespace rnn_utils ;
29
29
30
- status_t _ref_rnn_common_t ::cell_execution (const cell_ctx_t &cell_struct) {
30
+ status_t ref_rnn_fwd_t ::cell_execution (const cell_ctx_t &cell_struct) {
31
31
32
32
auto cell_layer = cell_struct.workspace .states_range (cell_struct.lay ,
33
33
cell_struct.lay , cell_struct.dir , cell_struct.dir , cell_struct.iter ,
@@ -48,11 +48,11 @@ status_t _ref_rnn_common_t::cell_execution(const cell_ctx_t &cell_struct) {
48
48
auto wei_iter
49
49
= cell_struct.user_data .wei_iter (cell_struct.lay , cell_struct.dir );
50
50
51
- CHECK (gemm_primitive (cell_struct.engine , cell_struct.ctx , wei_layer,
52
- cell_layer, scratch_gates, gemm_layer_fwd ));
51
+ CHECK (matmul_primitive (cell_struct.engine , cell_struct.ctx , wei_layer,
52
+ cell_layer, scratch_gates, matmul_layer_fwd ));
53
53
54
- CHECK (gemm_primitive (cell_struct.engine , cell_struct.ctx , wei_iter,
55
- cell_iter, scratch_gates, gemm_iter_fwd ));
54
+ CHECK (matmul_primitive (cell_struct.engine , cell_struct.ctx , wei_iter,
55
+ cell_iter, scratch_gates, matmul_iter_fwd ));
56
56
57
57
CHECK (rnn_bias (cell_struct.ctx , cell_struct.rnn .mb , cell_struct.rnn .dhc ,
58
58
cell_struct.iter , cell_struct.lay , cell_struct.dir ,
@@ -61,6 +61,100 @@ status_t _ref_rnn_common_t::cell_execution(const cell_ctx_t &cell_struct) {
61
61
return status::success;
62
62
}
63
63
64
+ status_t ref_rnn_bwd_t::cell_execution (const cell_ctx_t &cell_struct) {
65
+
66
+ auto wei_layer = cell_struct.user_data .wei_layer (
67
+ cell_struct.rnn .n_layer - cell_struct.lay - 1 , cell_struct.dir );
68
+ auto wei_iter = cell_struct.user_data .wei_iter (
69
+ cell_struct.rnn .n_layer - cell_struct.lay - 1 , cell_struct.dir );
70
+
71
+ auto ws_gates = cell_struct.workspace .gates (
72
+ cell_struct.rnn .n_layer - cell_struct.lay , cell_struct.dir ,
73
+ cell_struct.rnn .n_iter - cell_struct.iter - 1 );
74
+
75
+ // take into account reading first layer from end of state
76
+ // subsequent layers at written at forward pass'
77
+ // last layer step location in state
78
+ auto dsl_lay_off = cell_struct.iter == 0
79
+ ? cell_struct.rnn .n_layer * 2 - cell_struct.lay - 1
80
+ : cell_struct.rnn .n_layer - cell_struct.lay - 1 ;
81
+ // take into account for bidirectional case
82
+ // any timesteps after first one
83
+ // write out from bottom up of timestep state location
84
+ auto dsl_iter_offset = cell_struct.iter == 0 ? 0
85
+ : cell_struct.dir == 1
86
+ ? cell_struct.iter
87
+ : cell_struct.rnn .n_iter - cell_struct.iter + 1 ;
88
+ auto diff_cell_layer = cell_struct.scratch .diff_states (
89
+ dsl_lay_off, cell_struct.dir , dsl_iter_offset);
90
+
91
+ // account for bidirectional cases needing to jumping over
92
+ // n_iter state blocks
93
+ auto dir_off = cell_struct.dir == 1
94
+ ? cell_struct.iter + 1
95
+ : cell_struct.rnn .n_iter - cell_struct.iter ;
96
+ auto diff_cell_iter = cell_struct.scratch .diff_states (
97
+ cell_struct.rnn .n_layer - cell_struct.lay , cell_struct.dir ,
98
+ dir_off);
99
+
100
+ auto wei_cell_layer = cell_struct.workspace .states_range (
101
+ cell_struct.rnn .n_layer - 1 - cell_struct.lay ,
102
+ cell_struct.rnn .n_layer - 1 - cell_struct.lay , cell_struct.dir ,
103
+ cell_struct.dir , cell_struct.rnn .n_iter - cell_struct.iter - 1 ,
104
+ cell_struct.rnn .n_iter - cell_struct.iter - 1 );
105
+
106
+ auto wci_offset = cell_struct.rnn .n_iter - cell_struct.iter - 2 ;
107
+
108
+ if (cell_struct.rnn .n_dir == 2 ) {
109
+ wci_offset = cell_struct.rnn .n_iter - cell_struct.iter - 1 == 0
110
+ ? cell_struct.rnn .n_iter - cell_struct.iter
111
+ - cell_struct.rnn .n_iter - 3
112
+ : cell_struct.rnn .n_iter - cell_struct.iter - 2 ;
113
+ }
114
+
115
+ auto wei_cell_iter = cell_struct.workspace .states_range (
116
+ cell_struct.rnn .n_layer - cell_struct.lay ,
117
+ cell_struct.rnn .n_layer - cell_struct.lay , cell_struct.dir ,
118
+ cell_struct.dir , wci_offset, wci_offset);
119
+
120
+ auto diff_gates = cell_struct.scratch .diff_gates (0 );
121
+
122
+ CHECK (rnn_bias (cell_struct.ctx , cell_struct.rnn .mb , cell_struct.rnn .dhc ,
123
+ cell_struct.iter , cell_struct.lay , cell_struct.dir ,
124
+ cell_struct.rnn .n_layer , diff_cell_layer, diff_cell_iter,
125
+ cell_struct.user_data , ws_gates, diff_gates));
126
+
127
+ auto dsi_offset = cell_struct.dir == 1
128
+ ? cell_struct.iter + 1
129
+ : cell_struct.rnn .n_iter - cell_struct.iter ;
130
+
131
+ auto diff_states_layer = cell_struct.scratch .diff_states (
132
+ cell_struct.rnn .n_layer - cell_struct.lay - 1 , cell_struct.dir ,
133
+ dsi_offset, 0 );
134
+
135
+ CHECK (matmul_primitive (cell_struct.engine , cell_struct.ctx , wei_iter,
136
+ diff_gates, diff_states_layer, matmul_iter_bwd));
137
+
138
+ auto diff_states_iter = cell_struct.scratch .diff_states (
139
+ cell_struct.rnn .n_layer - cell_struct.lay - 1 , cell_struct.dir ,
140
+ dsi_offset, 1 );
141
+
142
+ CHECK (matmul_primitive (cell_struct.engine , cell_struct.ctx , wei_layer,
143
+ diff_gates, diff_states_iter, matmul_layer_bwd));
144
+
145
+ auto diff_wei_layer = cell_struct.user_data .diff_wei_layer (
146
+ cell_struct.rnn .n_layer - cell_struct.lay - 1 , cell_struct.dir );
147
+ CHECK (matmul_primitive (cell_struct.engine , cell_struct.ctx , diff_gates,
148
+ wei_cell_layer, diff_wei_layer, matmul_diff_wei_layer));
149
+
150
+ auto diff_wei_iter = cell_struct.user_data .diff_wei_iter (
151
+ cell_struct.rnn .n_layer - cell_struct.lay - 1 , cell_struct.dir );
152
+ CHECK (matmul_primitive (cell_struct.engine , cell_struct.ctx , diff_gates,
153
+ wei_cell_iter, diff_wei_iter, matmul_diff_wei_iter));
154
+
155
+ return status::success;
156
+ }
157
+
64
158
} // namespace sycl
65
159
} // namespace generic
66
160
} // namespace gpu
0 commit comments