Skip to content

Commit ebf8a71

Browse files
authored
fix: get tensors by const ref to not rely on deleted move constructor for TensorView (flashinfer-ai#2602)
<!-- .github/pull_request_template.md --> ## 📌 Description Getting a `Tensor` out of an `Array` as a `TensorView` attempts to call the deleted move-constructor `TensorView(Tensor&&)`. We can instead get a const ref to the tensors out of the array. Error messages before these changes: ``` /workspace/flashinfer/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu(980): error: function "tvm::ffi::TensorView::TensorView(tvm::ffi::Tensor &&)" (declared at line 717 of /workspace/venv/lib/python3.12/site-packages/tvm_ffi/include/tvm/ffi/container/tensor.h) cannot be referenced -- it is a deleted function TensorView fc1_global = quant_scales.value()[1]; ^ ``` No errors after the fix. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Internal optimization to quantization handling in the fused mixture of experts module for improved code efficiency. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 80f4de4 commit ebf8a71

1 file changed

Lines changed: 22 additions & 22 deletions

File tree

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -976,10 +976,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
976976
TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4)
977977
<< "Expecting 4 quant scales for W4A8_MXFP4_MXFP8 quantization";
978978

979-
TensorView fc1_weight_block = quant_scales.value()[0];
980-
TensorView fc1_global = quant_scales.value()[1];
981-
TensorView fc2_weight_block = quant_scales.value()[2];
982-
TensorView fc2_global = quant_scales.value()[3];
979+
auto const& fc1_weight_block = quant_scales.value()[0];
980+
auto const& fc1_global = quant_scales.value()[1];
981+
auto const& fc2_weight_block = quant_scales.value()[2];
982+
auto const& fc2_global = quant_scales.value()[3];
983983

984984
// The input for scale fc1_weight_block / fc2_weight_block is packed into INT32
985985
constexpr int FP8_PER_INT32 = 4;
@@ -1035,12 +1035,12 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
10351035
TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 6)
10361036
<< "Expecting 6 quant scales for nvfp4 quantization";
10371037

1038-
TensorView fc1_act_global = quant_scales.value()[0];
1039-
TensorView fc1_weight_block = quant_scales.value()[1];
1040-
TensorView fc1_global = quant_scales.value()[2];
1041-
TensorView fc2_act_global = quant_scales.value()[3];
1042-
TensorView fc2_weight_block = quant_scales.value()[4];
1043-
TensorView fc2_global = quant_scales.value()[5];
1038+
auto const& fc1_act_global = quant_scales.value()[0];
1039+
auto const& fc1_weight_block = quant_scales.value()[1];
1040+
auto const& fc1_global = quant_scales.value()[2];
1041+
auto const& fc2_act_global = quant_scales.value()[3];
1042+
auto const& fc2_weight_block = quant_scales.value()[4];
1043+
auto const& fc2_global = quant_scales.value()[5];
10441044

10451045
// The input for scale fc1_weight_block / fc2_weight_block is packed into INT32
10461046
constexpr int FP8_PER_INT32 = 4;
@@ -1118,8 +1118,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
11181118
static_cast<float const*>(fc2_global.data_ptr()), fc1_act_global.ndim() == 1,
11191119
fc2_act_global.ndim() == 1);
11201120
} else if (mUseDeepSeekFP8BlockScaling) {
1121-
TensorView fc1_scales = quant_scales.value()[0];
1122-
TensorView fc2_scales = quant_scales.value()[1];
1121+
auto const& fc1_scales = quant_scales.value()[0];
1122+
auto const& fc2_scales = quant_scales.value()[1];
11231123
return kernels::QuantParams::FP8BlockScaling(
11241124
static_cast<float const*>(fc1_scales.data_ptr()),
11251125
static_cast<float const*>(fc2_scales.data_ptr()));
@@ -1128,8 +1128,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
11281128
TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 2)
11291129
<< "Expecting 2 quant scales for W4A16 quantization";
11301130

1131-
TensorView fc1_weight_scales = quant_scales.value()[0];
1132-
TensorView fc2_weight_scales = quant_scales.value()[1];
1131+
auto const& fc1_weight_scales = quant_scales.value()[0];
1132+
auto const& fc2_weight_scales = quant_scales.value()[1];
11331133
int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size;
11341134
return kernels::QuantParams::GroupWise(group_size,
11351135
static_cast<void const*>(fc1_weight_scales.data_ptr()),
@@ -1139,14 +1139,14 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
11391139
TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for INT4 quantization";
11401140
TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 8)
11411141
<< "Expecting 8 quant scales for INT4 quantization";
1142-
TensorView fc1_weight_scales = quant_scales.value()[0];
1143-
TensorView fc2_weight_scales = quant_scales.value()[1];
1144-
TensorView fc1_act_scales = quant_scales.value()[2];
1145-
TensorView fc2_act_scales = quant_scales.value()[3];
1146-
TensorView fc1_weight_zeros = quant_scales.value()[4];
1147-
TensorView fc2_weight_zeros = quant_scales.value()[5];
1148-
TensorView fc1_alpha = quant_scales.value()[6];
1149-
TensorView fc2_alpha = quant_scales.value()[7];
1142+
auto const& fc1_weight_scales = quant_scales.value()[0];
1143+
auto const& fc2_weight_scales = quant_scales.value()[1];
1144+
auto const& fc1_act_scales = quant_scales.value()[2];
1145+
auto const& fc2_act_scales = quant_scales.value()[3];
1146+
auto const& fc1_weight_zeros = quant_scales.value()[4];
1147+
auto const& fc2_weight_zeros = quant_scales.value()[5];
1148+
auto const& fc1_alpha = quant_scales.value()[6];
1149+
auto const& fc2_alpha = quant_scales.value()[7];
11501150
int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
11511151
return kernels::QuantParams::GroupWise(
11521152
group_size, static_cast<void const*>(fc1_weight_scales.data_ptr()),

0 commit comments

Comments (0)