@@ -27,7 +27,9 @@ namespace sycl {
27
27
using namespace dnnl ::impl::utils;
28
28
using namespace rnn_utils ;
29
29
30
- status_t _ref_rnn_common_t::cell_execution (const cell_ctx_t &cell_struct) {
30
+ template <prop_kind_t aprop>
31
+ status_t _ref_rnn_common_t <aprop>::cell_execution(
32
+ const cell_ctx_t &cell_struct) {
31
33
32
34
auto cell_layer = cell_struct.workspace .states_range (cell_struct.lay ,
33
35
cell_struct.lay , cell_struct.dir , cell_struct.dir , cell_struct.iter ,
@@ -48,19 +50,114 @@ status_t _ref_rnn_common_t::cell_execution(const cell_ctx_t &cell_struct) {
48
50
auto wei_iter
49
51
= cell_struct.user_data .wei_iter (cell_struct.lay , cell_struct.dir );
50
52
51
- CHECK (gemm_primitive (cell_struct.engine , cell_struct.ctx , wei_layer,
52
- cell_layer, scratch_gates, gemm_layer_fwd));
53
+ if (aprop == prop_kind::forward) {
53
54
54
- CHECK (gemm_primitive (cell_struct.engine , cell_struct.ctx , wei_iter ,
55
- cell_iter , scratch_gates, gemm_iter_fwd ));
55
+ CHECK (gemm_primitive (cell_struct.engine , cell_struct.ctx , wei_layer ,
56
+ cell_layer , scratch_gates, gemm_layer_fwd ));
56
57
57
- CHECK (rnn_bias (cell_struct.ctx , cell_struct.rnn .mb , cell_struct.rnn .dhc ,
58
- cell_struct.iter , cell_struct.lay , cell_struct.dir ,
59
- cell_struct.workspace , cell_struct.scratch , cell_struct.user_data ));
58
+ CHECK (gemm_primitive (cell_struct.engine , cell_struct.ctx , wei_iter,
59
+ cell_iter, scratch_gates, gemm_iter_fwd));
60
+
61
+ CHECK (rnn_bias_fwd (cell_struct.ctx , cell_struct.rnn .mb ,
62
+ cell_struct.rnn .dhc , cell_struct.iter , cell_struct.lay ,
63
+ cell_struct.dir , cell_struct.workspace , cell_struct.scratch ,
64
+ cell_struct.user_data ));
65
+ } else { // backward
66
+
67
+ wei_layer = cell_struct.user_data .wei_layer (
68
+ cell_struct.rnn .n_layer - cell_struct.lay - 1 , cell_struct.dir );
69
+ wei_iter = cell_struct.user_data .wei_iter (
70
+ cell_struct.rnn .n_layer - cell_struct.lay - 1 , cell_struct.dir );
71
+
72
+ auto ws_gates = cell_struct.workspace .gates (
73
+ cell_struct.rnn .n_layer - cell_struct.lay , cell_struct.dir ,
74
+ cell_struct.rnn .n_iter - cell_struct.iter - 1 );
75
+
76
+ auto dsl_lay_off = cell_struct.iter == 0
77
+ ? cell_struct.rnn .n_layer * 2 - cell_struct.lay - 1
78
+ : cell_struct.rnn .n_layer - cell_struct.lay - 1 ;
79
+
80
+ auto dsl_iter_offset = cell_struct.iter == 0 ? 0
81
+ : cell_struct.dir == 1
82
+ ? cell_struct.iter
83
+ : cell_struct.rnn .n_iter - cell_struct.iter + 1 ;
84
+ auto diff_cell_layer = cell_struct.scratch .diff_states (
85
+ dsl_lay_off, cell_struct.dir , dsl_iter_offset);
86
+
87
+ auto dir_off = cell_struct.dir == 1
88
+ ? cell_struct.iter + 1
89
+ : cell_struct.rnn .n_iter - cell_struct.iter ;
90
+ auto diff_cell_iter = cell_struct.scratch .diff_states (
91
+ cell_struct.rnn .n_layer - cell_struct.lay , cell_struct.dir ,
92
+ dir_off);
93
+
94
+ auto wei_cell_layer = cell_struct.workspace .states_range (
95
+ cell_struct.rnn .n_layer - 1 - cell_struct.lay ,
96
+ cell_struct.rnn .n_layer - 1 - cell_struct.lay , cell_struct.dir ,
97
+ cell_struct.dir , cell_struct.rnn .n_iter - cell_struct.iter - 1 ,
98
+ cell_struct.rnn .n_iter - cell_struct.iter - 1 );
99
+
100
+ auto wci_offset = cell_struct.rnn .n_iter - cell_struct.iter - 2 ;
101
+
102
+ if (cell_struct.rnn .n_dir == 2 ) {
103
+ wci_offset = cell_struct.rnn .n_iter - cell_struct.iter - 1 == 0
104
+ ? cell_struct.rnn .n_iter - cell_struct.iter
105
+ - cell_struct.rnn .n_iter - 3
106
+ : cell_struct.rnn .n_iter - cell_struct.iter - 2 ;
107
+ }
108
+
109
+ auto wei_cell_iter = cell_struct.workspace .states_range (
110
+ cell_struct.rnn .n_layer - cell_struct.lay ,
111
+ cell_struct.rnn .n_layer - cell_struct.lay , cell_struct.dir ,
112
+ cell_struct.dir , wci_offset, wci_offset);
113
+
114
+ auto diff_gates = cell_struct.scratch .diff_gates (0 );
115
+
116
+ CHECK (rnn_bias_bwd (cell_struct.ctx , cell_struct.rnn .mb ,
117
+ cell_struct.rnn .dhc , cell_struct.iter , cell_struct.lay ,
118
+ cell_struct.dir , cell_struct.rnn .n_layer , diff_cell_layer,
119
+ diff_cell_iter, cell_struct.user_data , ws_gates, diff_gates));
120
+
121
+ auto diff_states_1ay = cell_struct.scratch .diff_states (
122
+ cell_struct.rnn .n_layer - cell_struct.lay - 1 , cell_struct.dir ,
123
+ cell_struct.rnn .n_iter - cell_struct.iter );
124
+ auto dsi_offset = cell_struct.dir == 1
125
+ ? cell_struct.iter + 1
126
+ : cell_struct.rnn .n_iter - cell_struct.iter ;
127
+
128
+ auto diff_states_layer = cell_struct.scratch .diff_states (
129
+ cell_struct.rnn .n_layer - cell_struct.lay - 1 , cell_struct.dir ,
130
+ dsi_offset, 0 );
131
+
132
+ CHECK (gemm_primitive (cell_struct.engine , cell_struct.ctx , wei_iter,
133
+ diff_gates, diff_states_layer, gemm_iter_bwd));
134
+
135
+ auto diff_states_iter = cell_struct.scratch .diff_states (
136
+ cell_struct.rnn .n_layer - cell_struct.lay - 1 , cell_struct.dir ,
137
+ dsi_offset, 1 );
138
+
139
+ CHECK (gemm_primitive (cell_struct.engine , cell_struct.ctx , wei_layer,
140
+ diff_gates, diff_states_iter, gemm_layer_bwd));
141
+
142
+ auto diff_wei_layer = cell_struct.user_data .diff_wei_layer (
143
+ cell_struct.rnn .n_layer - cell_struct.lay - 1 , cell_struct.dir );
144
+ CHECK (gemm_primitive (cell_struct.engine , cell_struct.ctx , diff_gates,
145
+ wei_cell_layer, diff_wei_layer, gemm_diff_wei_layer));
146
+
147
+ auto diff_wei_iter = cell_struct.user_data .diff_wei_iter (
148
+ cell_struct.rnn .n_layer - cell_struct.lay - 1 , cell_struct.dir );
149
+ CHECK (gemm_primitive (cell_struct.engine , cell_struct.ctx , diff_gates,
150
+ wei_cell_iter, diff_wei_iter, gemm_diff_wei_iter));
151
+ }
60
152
61
153
return status::success;
62
154
}
63
155
156
// Explicit instantiation definitions of cell_execution for the forward and
// backward propagation kinds. The member template is defined in this file,
// so these presumably exist to make both specializations available to other
// translation units -- confirm against the class declaration.
+ template status_t _ref_rnn_common_t <prop_kind::forward>::cell_execution(
157
+ const cell_ctx_t &cell_struct);
158
+ template status_t _ref_rnn_common_t <prop_kind::backward>::cell_execution(
159
+ const cell_ctx_t &cell_struct);
160
+
64
161
} // namespace sycl
65
162
} // namespace generic
66
163
} // namespace gpu
0 commit comments