@@ -202,11 +202,12 @@ class Op
202202
203203 const std::string &as_string () const ;
204204
205- /* ------------------------------------------------------------ *
206- * Calculations segment. Derived classes must implement these. *
207- * ------------------------------------------------------------ */
205+ /* ----------------------------------------------------*
206+ * Calculations segment. All ops must implement these. *
207+ * ----------------------------------------------------*/
208208
209209 at::Tensor eval (const graphlib::OpType &old_op_type, const std::vector<at::Tensor> &tensors) const ;
210+
210211 std::tuple<graphlib::Shape, std::vector<graphlib::DimBroadcastTrampoline>> shape (
211212 const graphlib::OpType &old_op_type, const std::vector<std::vector<std::uint32_t >> &inputs) const ;
212213
@@ -218,6 +219,16 @@ class Op
218219 const tt::graphlib::NodeContext &output,
219220 const tt::graphlib::NodeContext &gradient) const ;
220221
222+ bool is_tm (const graphlib::OpType &old_op_type) const ;
223+ bool is_eltwise (const graphlib::OpType &old_op_type) const ;
224+ bool is_eltwise_unary (const graphlib::OpType &old_op_type) const ;
225+ bool is_eltwise_binary (const graphlib::OpType &old_op_type) const ;
226+ bool is_eltwise_nary (const graphlib::OpType &old_op_type) const ;
227+
228+ /* --------------------------*
229+ * Optional implementations. *
230+ * --------------------------*/
231+
221232 void decompose (
222233 const graphlib::OpType &old_op_type,
223234 const char *dispatch,
@@ -227,18 +238,13 @@ class Op
227238 long initial_flops_estimate (
228239 const graphlib::OpType &old_op_type, const std::vector<std::vector<std::uint32_t >> &inputs) const ;
229240
230- bool is_tm (const graphlib::OpType &old_op_type) const ;
231- bool is_eltwise (const graphlib::OpType &old_op_type) const ;
232- bool is_eltwise_unary (const graphlib::OpType &old_op_type) const ;
233- bool is_eltwise_binary (const graphlib::OpType &old_op_type) const ;
234- bool is_eltwise_nary (const graphlib::OpType &old_op_type) const ;
235-
236241 private:
237242 /* ------------------------------------------------------------*
238243 * Base - common for all ops that are not yet migrated to cpp. *
239244 * ------------------------------------------------------------*/
240245
241246 at::Tensor base_eval (const graphlib::OpType &old_op_type, const std::vector<at::Tensor> &tensors) const ;
247+
242248 std::tuple<graphlib::Shape, std::vector<graphlib::DimBroadcastTrampoline>> base_shape (
243249 const graphlib::OpType &old_op_type, const std::vector<std::vector<std::uint32_t >> &inputs) const ;
244250
@@ -250,6 +256,12 @@ class Op
250256 const tt::graphlib::NodeContext &output,
251257 const tt::graphlib::NodeContext &gradient) const ;
252258
259+ bool base_is_tm (const graphlib::OpType &old_op_type) const ;
260+ bool base_is_eltwise (const graphlib::OpType &old_op_type) const ;
261+ bool base_is_eltwise_unary (const graphlib::OpType &old_op_type) const ;
262+ bool base_is_eltwise_binary (const graphlib::OpType &old_op_type) const ;
263+ bool base_is_eltwise_nary (const graphlib::OpType &old_op_type) const ;
264+
253265 void base_decompose (
254266 const graphlib::OpType &old_op_type,
255267 const char *dispatch,
@@ -259,12 +271,6 @@ class Op
259271 long base_initial_flops_estimate (
260272 const graphlib::OpType &old_op_type, const std::vector<std::vector<std::uint32_t >> &inputs) const ;
261273
262- bool base_is_tm (const graphlib::OpType &old_op_type) const ;
263- bool base_is_eltwise (const graphlib::OpType &old_op_type) const ;
264- bool base_is_eltwise_unary (const graphlib::OpType &old_op_type) const ;
265- bool base_is_eltwise_binary (const graphlib::OpType &old_op_type) const ;
266- bool base_is_eltwise_nary (const graphlib::OpType &old_op_type) const ;
267-
268274 /* -----------------------------*
269275 * Ops specific implementation. *
270276 * -----------------------------*/
@@ -274,6 +280,7 @@ class Op
274280 * -------------*/
275281
276282 at::Tensor abs_eval (const std::vector<at::Tensor> &tensors) const ;
283+
277284 std::tuple<graphlib::Shape, std::vector<graphlib::DimBroadcastTrampoline>> abs_shape (
278285 const std::vector<std::vector<std::uint32_t >> &inputs) const ;
279286
@@ -286,6 +293,22 @@ class Op
286293
287294 long abs_initial_flops_estimate (const std::vector<std::vector<std::uint32_t >> &inputs) const ;
288295
296+ /* ------------------*
297+ * OpType::Constant. *
298+ * ------------------*/
299+
300+ at::Tensor constant_eval (const std::vector<at::Tensor> &tensors) const ;
301+
302+ std::tuple<graphlib::Shape, std::vector<graphlib::DimBroadcastTrampoline>> constant_shape (
303+ const std::vector<std::vector<std::uint32_t >> &inputs) const ;
304+
305+ tt::graphlib::NodeContext constant_backward (
306+ tt::autograd::autograd_context &context,
307+ int operand,
308+ const std::vector<tt::graphlib::NodeContext> &inputs,
309+ const tt::graphlib::NodeContext &output,
310+ const tt::graphlib::NodeContext &gradient) const ;
311+
289312 private:
290313 OpType type_;
291314 Attrs attrs_;
0 commit comments