@@ -125,6 +125,7 @@ class NewToOldOpType
125125 mapping_[OpType::Unsqueeze] = " unsqueeze" ;
126126 mapping_[OpType::UpdateCache] = " update_cache" ;
127127 mapping_[OpType::Upsample2d] = " upsample2d" ;
128+ mapping_[OpType::TopK] = " topk" ;
128129 mapping_[OpType::Where] = " where" ;
129130 }
130131
@@ -227,6 +228,7 @@ class OldToNewOpType
227228 mapping_[" unsqueeze" ] = OpType::Unsqueeze;
228229 mapping_[" update_cache" ] = OpType::UpdateCache;
229230 mapping_[" upsample2d" ] = OpType::Upsample2d;
231+ mapping_[" topk" ] = OpType::TopK;
230232 mapping_[" where" ] = OpType::Where;
231233 }
232234
@@ -394,6 +396,7 @@ at::Tensor Op::eval(const graphlib::OpType &old_op_type, const std::vector<at::T
394396 case OpType::Unsqueeze: return unsqueeze::eval (old_op_type, *this , tensors);
395397 case OpType::UpdateCache: return update_cache::eval (old_op_type, *this , tensors);
396398 case OpType::Upsample2d: return upsample_2d::eval (old_op_type, *this , tensors);
399+ case OpType::TopK: return topk::eval (old_op_type, *this , tensors);
397400 case OpType::Where: return where::eval (old_op_type, *this , tensors);
398401 default : TT_ASSERT (false , " Unknown OpType." ); unreachable ();
399402 } // clang-format on
@@ -489,6 +492,7 @@ std::tuple<graphlib::Shape, std::vector<graphlib::DimBroadcast>> Op::shape(
489492 case OpType::Unsqueeze: return unsqueeze::shape (old_op_type, *this , inputs);
490493 case OpType::UpdateCache: return update_cache::shape (old_op_type, *this , inputs);
491494 case OpType::Upsample2d: return upsample_2d::shape (old_op_type, *this , inputs);
495+ case OpType::TopK: return topk::shape (old_op_type, *this , inputs);
492496 case OpType::Where: return where::shape (old_op_type, *this , inputs);
493497 default : TT_ASSERT (false , " Unknown OpType." ); unreachable ();
494498 } // clang-format on
@@ -589,6 +593,7 @@ tt::graphlib::NodeContext Op::backward(
589593 case OpType::Unsqueeze: return unsqueeze::backward (old_op_type, *this , context, operand, inputs, output, gradient);
590594 case OpType::UpdateCache: return update_cache::backward (old_op_type, *this , context, operand, inputs, output, gradient);
591595 case OpType::Upsample2d: return upsample_2d::backward (old_op_type, *this , context, operand, inputs, output, gradient);
596+ case OpType::TopK: return topk::backward (old_op_type, *this , context, operand, inputs, output, gradient);
592597 case OpType::Where: return where::backward (old_op_type, *this , context, operand, inputs, output, gradient);
593598 default : TT_ASSERT (false , " Unknown OpType." ); unreachable ();
594599 } // clang-format on
@@ -706,6 +711,7 @@ void Op::decompose_initial(
706711 case OpType::Unsqueeze: return ;
707712 case OpType::UpdateCache: return ;
708713 case OpType::Upsample2d: return ;
714+ case OpType::TopK: return ;
709715 case OpType::Where: return where::decompose_initial (old_op_type, *this , dc, inputs);
710716 default : TT_ASSERT (false , " Unknown OpType." ); unreachable ();
711717 } // clang-format on
@@ -802,6 +808,7 @@ void Op::decompose_post_optimize(
802808 case OpType::Unsqueeze: return ;
803809 case OpType::UpdateCache: return ;
804810 case OpType::Upsample2d: return ;
811+ case OpType::TopK: return ;
805812 case OpType::Where: return where::decompose_post_optimize (old_op_type, *this , dc, inputs);
806813 default : TT_ASSERT (false , " Unknown OpType." ); unreachable ();
807814 } // clang-format on
@@ -899,6 +906,7 @@ void Op::decompose_post_autograd(
899906 case OpType::Unsqueeze: return ;
900907 case OpType::UpdateCache: return ;
901908 case OpType::Upsample2d: return ;
909+ case OpType::TopK: return ;
902910 case OpType::Where: return where::decompose_post_autograd (old_op_type, *this , dc, inputs);
903911 default : TT_ASSERT (false , " Unknown OpType." ); unreachable ();
904912 } // clang-format on
@@ -994,6 +1002,7 @@ long Op::initial_flops_estimate(
9941002 case OpType::Unsqueeze: return 0 ;
9951003 case OpType::UpdateCache: return 0 ;
9961004 case OpType::Upsample2d: return 0 ;
1005+ case OpType::TopK: return 0 ;
9971006 case OpType::Where: return where::initial_flops_estimate (old_op_type, *this , inputs);
9981007 default : TT_ASSERT (false , " Unknown OpType." ); unreachable ();
9991008 } // clang-format on
@@ -1088,6 +1097,7 @@ bool Op::is_tm(const graphlib::OpType &old_op_type) const
10881097 case OpType::Unsqueeze: return true ;
10891098 case OpType::UpdateCache: return false ;
10901099 case OpType::Upsample2d: return false ;
1100+ case OpType::TopK: return false ;
10911101 case OpType::Where: return false ;
10921102 default : TT_ASSERT (false , " Unknown OpType." ); unreachable ();
10931103 }
@@ -1182,6 +1192,7 @@ bool Op::is_eltwise(const graphlib::OpType &old_op_type) const
11821192 case OpType::Unsqueeze: return false ;
11831193 case OpType::UpdateCache: return false ;
11841194 case OpType::Upsample2d: return false ;
1195+ case OpType::TopK: return false ;
11851196 case OpType::Where: return true ;
11861197 default : TT_ASSERT (false , " Unknown OpType." ); unreachable ();
11871198 }
@@ -1276,6 +1287,7 @@ bool Op::is_eltwise_unary(const graphlib::OpType &old_op_type) const
12761287 case OpType::Unsqueeze: return false ;
12771288 case OpType::UpdateCache: return false ;
12781289 case OpType::Upsample2d: return false ;
1290+ case OpType::TopK: return false ;
12791291 case OpType::Where: return false ;
12801292 default : TT_ASSERT (false , " Unknown OpType." ); unreachable ();
12811293 }
@@ -1370,6 +1382,7 @@ bool Op::is_eltwise_binary(const graphlib::OpType &old_op_type) const
13701382 case OpType::Unsqueeze: return false ;
13711383 case OpType::UpdateCache: return false ;
13721384 case OpType::Upsample2d: return false ;
1385+ case OpType::TopK: return false ;
13731386 case OpType::Where: return false ;
13741387 default : TT_ASSERT (false , " Unknown OpType." ); unreachable ();
13751388 }
@@ -1463,6 +1476,7 @@ bool Op::is_eltwise_nary(const graphlib::OpType &old_op_type) const
14631476 case OpType::Unsqueeze: return false ;
14641477 case OpType::UpdateCache: return false ;
14651478 case OpType::Upsample2d: return false ;
1479+ case OpType::TopK: return false ;
14661480 case OpType::Where: return true ;
14671481 default : TT_ASSERT (false , " Unknown OpType." ); unreachable ();
14681482 }
0 commit comments