Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -866,28 +866,6 @@ struct PackedFuncValueConverter<tvm::FloatImm> {
}
};

/*! \brief Backwards-compatibility converter for integer FFI arguments.
 *
 * Older TVM code passed integer arguments as tvm::IntImm rather than
 * runtime::Int.  When a callee now expects runtime::Int but the caller
 * still supplies an IntImm (e.g. relay script parsing), and the
 * automatic unboxing of runtime::Int does not kick in (e.g. when
 * building an `Array<runtime::Int>`), accept the IntImm and unwrap its
 * stored value.
 */
template <>
struct PackedFuncValueConverter<runtime::Int> {
  template <typename PODSubclass>
  static runtime::Int From(const PODSubclass& val) {
    // Fast path: the argument already holds a runtime::Int.
    if (!val.template IsObjectRef<tvm::IntImm>()) {
      return val.template AsObjectRef<runtime::Int>();
    }
    // Legacy path: unwrap the IntImm's underlying integer value.
    return runtime::Int(val.template AsObjectRef<tvm::IntImm>()->value);
  }
};

} // namespace runtime
} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/meta_schedule/schedule/cuda/thread_bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace meta_schedule {
* \return A sampler that returns a random thread extent.
*/
std::function<tir::ExprRV(int64_t)> MakeFactorSampler(tir::Schedule sch,
Array<Integer> thread_extents);
Array<runtime::Int> thread_extents);

/*!
* \brief Bind blockIdx.x and threadIdx.x to the given loop
Expand Down
25 changes: 13 additions & 12 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,11 @@ class ScheduleRule : public runtime::ObjectRef {
* ignored by default. This function should return True for a block that should be tiled.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTiling(String structure, //
Optional<Array<String>> tile_binds, //
Optional<Integer> max_innermost_factor, //
Optional<Array<Integer>> vector_load_lens, //
Optional<Map<String, ObjectRef>> reuse_read, //
TVM_DLL static ScheduleRule MultiLevelTiling(String structure, //
Optional<Array<String>> tile_binds, //
Optional<runtime::Int> max_innermost_factor, //
Optional<Array<runtime::Int>> vector_load_lens, //
Optional<Map<String, ObjectRef>> reuse_read, //
Optional<Map<String, ObjectRef>> reuse_write,
Optional<runtime::PackedFunc> filter_fn = NullOpt);

Expand All @@ -181,7 +181,7 @@ class ScheduleRule : public runtime::ObjectRef {
*/
TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin(
String intrin_name, String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Optional<runtime::Int> max_innermost_factor, Optional<Array<runtime::Int>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);

/*!
Expand All @@ -206,8 +206,8 @@ class ScheduleRule : public runtime::ObjectRef {
*/
TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
Array<Map<String, String>> intrin_groups, String structure,
Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
Optional<Array<String>> tile_binds, Optional<runtime::Int> max_innermost_factor,
Optional<Array<runtime::Int>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);

/*!
Expand All @@ -222,8 +222,9 @@ class ScheduleRule : public runtime::ObjectRef {
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTilingWideVector(
String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
String structure, runtime::Int vector_length_in_bits,
Optional<runtime::Int> max_innermost_factor, Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Create a rule: add-rfactor to some blocks if needed
Expand All @@ -234,7 +235,7 @@ class ScheduleRule : public runtime::ObjectRef {
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, //
Optional<Integer> max_innermost_factor);
Optional<runtime::Int> max_innermost_factor);
/*!
* \brief Create a schedule rule which applies cross-thread reduction to some reduction blocks
* correspondingly when needed
Expand Down Expand Up @@ -272,7 +273,7 @@ class ScheduleRule : public runtime::ObjectRef {
* when this schedule rule is created.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array<Integer> thread_extents,
TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array<runtime::Int> thread_extents,
int max_threads_per_block = -1);
/*!
* \brief Create a schedule rule with customized methods on the python-side.
Expand Down
5 changes: 4 additions & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,9 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
}
}; // struct MatrixSetDiagAttrs

/*! \brief Alias template that discards its first type argument and yields the
 * second (Identity<U, T> == T).
 *
 * NOTE(review): its only visible use is `Identity<FVisit, Bool>(false)` in
 * ScanopAttrs below — presumably to make the `Bool` construction a dependent
 * expression so it is resolved late during attribute visitation; confirm the
 * intent against the attribute-visitor machinery.
 */
template <typename U, typename T>
using Identity = T;

/*! \brief Attributes used in cumsum and cumprod operator */
struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
Integer axis;
Expand All @@ -542,7 +545,7 @@ struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
// Default is 0 which is "false"
TVM_ATTR_FIELD(exclusive)
.describe("The first element is not included")
.set_default(Bool(false));
.set_default(Identity<FVisit, Bool>(false));
}
}; // struct ScanopAttrs

Expand Down
8 changes: 4 additions & 4 deletions src/meta_schedule/schedule/cuda/thread_bind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ namespace meta_schedule {

using namespace tvm::tir;

std::function<ExprRV(int64_t)> MakeFactorSampler(Schedule sch, Array<Integer> thread_extents) {
std::function<ExprRV(int64_t)> MakeFactorSampler(Schedule sch, Array<runtime::Int> thread_extents) {
return [sch = std::move(sch),
thread_extents = std::move(thread_extents)](int64_t max_extent) -> ExprRV {
Array<runtime::Int> extents;
extents.reserve(thread_extents.size());
for (const Integer extent : thread_extents) {
for (auto extent : thread_extents) {
if (extent->value <= max_extent) {
extents.push_back(runtime::Int(extent->value));
extents.push_back(extent);
}
}
int n = extents.size();
Expand All @@ -64,7 +64,7 @@ Array<LoopRV> BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblock
}
if (extent <= max_threadblocks * max_threads_per_block) {
if (!get_factor) {
get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024});
get_factor = MakeFactorSampler(sch, Array<runtime::Int>{32, 64, 128, 256, 512, 1024});
}
ExprRV factor = get_factor(std::min(extent, max_threads_per_block));
Array<LoopRV> splits = sch->Split(loop, {NullOpt, factor});
Expand Down
4 changes: 2 additions & 2 deletions src/meta_schedule/schedule_rule/add_rfactor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ class AddRFactorNode : public ScheduleRuleNode {
};

ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core,
Optional<Integer> max_innermost_factor) {
Optional<runtime::Int> max_innermost_factor) {
ObjectPtr<AddRFactorNode> n = make_object<AddRFactorNode>();
n->max_jobs_per_core = max_jobs_per_core;
n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
n->max_innermost_factor = max_innermost_factor.value_or(runtime::Int(-1))->value;
n->max_parallel_extent_ = -1;
n->max_parallel_basic_ = -1;
return ScheduleRule(n);
Expand Down
10 changes: 5 additions & 5 deletions src/meta_schedule/schedule_rule/auto_bind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ class AutoBindNode : public ScheduleRuleNode {
// Inherited from ScheduleRuleNode
void InitializeWithTuneContext(const TuneContext& context) final {
CHECK(context->target.defined()) << "ValueError: target is not defined";
Optional<Integer> max_threads_per_block =
context->target.value()->GetAttr<Integer>("max_threads_per_block");
Optional<runtime::Int> max_threads_per_block =
context->target.value()->GetAttr<runtime::Int>("max_threads_per_block");
CHECK(max_threads_per_block.defined())
<< "ValueError: missing attribute `max_threads_per_block` in the target";
this->max_threads_per_block_ = max_threads_per_block.value().IntValue();
this->max_threads_per_block_ = max_threads_per_block.value();
}

// Inherited from ScheduleRuleNode
Expand All @@ -53,7 +53,7 @@ class AutoBindNode : public ScheduleRuleNode {
/*! \brief The max number of threadblocks in the cuda device */
int64_t max_threadblocks_ = -1;
/*! \brief thread_extents Candidates of thread axis extent. */
Array<Integer> thread_extents_;
Array<runtime::Int> thread_extents_;

void VisitAttrs(tvm::AttrVisitor* v) {
// `max_threads_per_block_` is not visited
Expand All @@ -72,7 +72,7 @@ Array<tir::Schedule> AutoBindNode::Apply(const tir::Schedule& sch, const tir::Bl
return {sch};
}

ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array<Integer> thread_extents,
ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array<runtime::Int> thread_extents,
int max_threads_per_block) {
ObjectPtr<AutoBindNode> n = make_object<AutoBindNode>();
n->max_threadblocks_ = max_threadblocks;
Expand Down
4 changes: 2 additions & 2 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,8 @@ void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
// Constructor

ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor,
Optional<Array<Integer>> vector_load_lens,
Optional<runtime::Int> max_innermost_factor,
Optional<Array<runtime::Int>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write,
Optional<runtime::PackedFunc> filter_fn) {
Expand Down
8 changes: 4 additions & 4 deletions src/meta_schedule/schedule_rule/multi_level_tiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,16 +239,16 @@ class MultiLevelTilingNode : public ScheduleRuleNode {

template <typename NodeType>
ObjectPtr<NodeType> MultiLevelTilingInitCommon(String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor,
Optional<Array<Integer>> vector_load_lens,
Optional<runtime::Int> max_innermost_factor,
Optional<Array<runtime::Int>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write) {
ObjectPtr<NodeType> n = make_object<NodeType>();
n->structure = structure;
n->tile_binds = tile_binds.value_or({});
n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
n->max_innermost_factor = max_innermost_factor.value_or(runtime::Int(-1))->value;
n->vector_load_lens = vector_load_lens.defined()
? support::AsVector<Integer, int>(vector_load_lens.value())
? support::AsVector<runtime::Int, int>(vector_load_lens.value())
: std::vector<int>();
n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,7 @@ inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorizat

ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
Array<Map<String, String>> intrin_groups, String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Optional<runtime::Int> max_innermost_factor, Optional<Array<runtime::Int>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write,
bool use_software_pipeline) {
if (tile_binds.defined()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingWideVectorNode
}

ScheduleRule ScheduleRule::MultiLevelTilingWideVector(
String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
String structure, runtime::Int vector_length_in_bits,
Optional<runtime::Int> max_innermost_factor, Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write) {
auto node = MultiLevelTilingInitCommon<MultiLevelTilingWideVectorNode>(
structure, NullOpt, max_innermost_factor, NullOpt, reuse_read, reuse_write);
node->vector_length_in_bits = vector_length_in_bits->value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {

ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(
String intrin_name, String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Optional<runtime::Int> max_innermost_factor, Optional<Array<runtime::Int>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
ICHECK(tir::TensorIntrin::Get(intrin_name).defined())
<< "Provided tensor intrinsic " << intrin_name << " is not registered.";
Expand Down
Loading