Skip to content

Commit cf1b599

Browse files
committed
Start fixing callers that relied on IntImm -> runtime::Int
1 parent c55d87c commit cf1b599

File tree

14 files changed

+82
-73
lines changed

14 files changed

+82
-73
lines changed

include/tvm/meta_schedule/schedule/cuda/thread_bind.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ namespace meta_schedule {
3636
* \return A sampler that returns a random thread extent.
3737
*/
3838
std::function<tir::ExprRV(int64_t)> MakeFactorSampler(tir::Schedule sch,
39-
Array<Integer> thread_extents);
39+
Array<runtime::Int> thread_extents);
4040

4141
/*!
4242
* \brief Bind blockIdx.x and threadIdx.x to the given loop

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,11 @@ class ScheduleRule : public runtime::ObjectRef {
154154
* ignored by default. This function should return True for a block that should be tiled.
155155
* \return The schedule rule created
156156
*/
157-
TVM_DLL static ScheduleRule MultiLevelTiling(String structure, //
158-
Optional<Array<String>> tile_binds, //
159-
Optional<Integer> max_innermost_factor, //
160-
Optional<Array<Integer>> vector_load_lens, //
161-
Optional<Map<String, ObjectRef>> reuse_read, //
157+
TVM_DLL static ScheduleRule MultiLevelTiling(String structure, //
158+
Optional<Array<String>> tile_binds, //
159+
Optional<runtime::Int> max_innermost_factor, //
160+
Optional<Array<runtime::Int>> vector_load_lens, //
161+
Optional<Map<String, ObjectRef>> reuse_read, //
162162
Optional<Map<String, ObjectRef>> reuse_write,
163163
Optional<runtime::PackedFunc> filter_fn = NullOpt);
164164

@@ -181,7 +181,7 @@ class ScheduleRule : public runtime::ObjectRef {
181181
*/
182182
TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin(
183183
String intrin_name, String structure, Optional<Array<String>> tile_binds,
184-
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
184+
Optional<runtime::Int> max_innermost_factor, Optional<Array<runtime::Int>> vector_load_lens,
185185
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
186186

187187
/*!
@@ -206,8 +206,8 @@ class ScheduleRule : public runtime::ObjectRef {
206206
*/
207207
TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
208208
Array<Map<String, String>> intrin_groups, String structure,
209-
Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
210-
Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
209+
Optional<Array<String>> tile_binds, Optional<runtime::Int> max_innermost_factor,
210+
Optional<Array<runtime::Int>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
211211
Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);
212212

213213
/*!
@@ -222,8 +222,9 @@ class ScheduleRule : public runtime::ObjectRef {
222222
* \return The schedule rule created
223223
*/
224224
TVM_DLL static ScheduleRule MultiLevelTilingWideVector(
225-
String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
226-
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
225+
String structure, runtime::Int vector_length_in_bits,
226+
Optional<runtime::Int> max_innermost_factor, Optional<Map<String, ObjectRef>> reuse_read,
227+
Optional<Map<String, ObjectRef>> reuse_write);
227228

228229
/*!
229230
* \brief Create a rule: add-rfactor to some blocks if needed
@@ -234,7 +235,7 @@ class ScheduleRule : public runtime::ObjectRef {
234235
* \return The schedule rule created
235236
*/
236237
TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, //
237-
Optional<Integer> max_innermost_factor);
238+
Optional<runtime::Int> max_innermost_factor);
238239
/*!
239240
* \brief Create a schedule rule which applies cross-thread reduction to some reduction blocks
240241
* correspondingly when needed
@@ -272,7 +273,7 @@ class ScheduleRule : public runtime::ObjectRef {
272273
* when this schedule rule is created.
273274
* \return The schedule rule created
274275
*/
275-
TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array<Integer> thread_extents,
276+
TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array<runtime::Int> thread_extents,
276277
int max_threads_per_block = -1);
277278
/*!
278279
* \brief Create a schedule rule with customized methods on the python-side.

include/tvm/relay/attrs/transform.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,9 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
530530
}
531531
}; // struct MatrixSetDiagAttrs
532532

533+
template <typename U, typename T>
534+
using Identity = T;
535+
533536
/*! \brief Attributes used in cumsum and cumprod operator */
534537
struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
535538
Integer axis;
@@ -542,7 +545,7 @@ struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
542545
// Default is 0 which is "false"
543546
TVM_ATTR_FIELD(exclusive)
544547
.describe("The first element is not included")
545-
.set_default(Bool(false));
548+
.set_default(Identity<FVisit, Bool>(false));
546549
}
547550
}; // struct ScanopAttrs
548551

src/meta_schedule/schedule/cuda/thread_bind.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ namespace meta_schedule {
3131

3232
using namespace tvm::tir;
3333

34-
std::function<ExprRV(int64_t)> MakeFactorSampler(Schedule sch, Array<Integer> thread_extents) {
34+
std::function<ExprRV(int64_t)> MakeFactorSampler(Schedule sch, Array<runtime::Int> thread_extents) {
3535
return [sch = std::move(sch),
3636
thread_extents = std::move(thread_extents)](int64_t max_extent) -> ExprRV {
3737
Array<runtime::Int> extents;
3838
extents.reserve(thread_extents.size());
39-
for (const Integer extent : thread_extents) {
39+
for (auto extent : thread_extents) {
4040
if (extent->value <= max_extent) {
41-
extents.push_back(runtime::Int(extent->value));
41+
extents.push_back(extent);
4242
}
4343
}
4444
int n = extents.size();
@@ -64,7 +64,7 @@ Array<LoopRV> BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblock
6464
}
6565
if (extent <= max_threadblocks * max_threads_per_block) {
6666
if (!get_factor) {
67-
get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024});
67+
get_factor = MakeFactorSampler(sch, Array<runtime::Int>{32, 64, 128, 256, 512, 1024});
6868
}
6969
ExprRV factor = get_factor(std::min(extent, max_threads_per_block));
7070
Array<LoopRV> splits = sch->Split(loop, {NullOpt, factor});

src/meta_schedule/schedule_rule/add_rfactor.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,10 @@ class AddRFactorNode : public ScheduleRuleNode {
6868
};
6969

7070
ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core,
71-
Optional<Integer> max_innermost_factor) {
71+
Optional<runtime::Int> max_innermost_factor) {
7272
ObjectPtr<AddRFactorNode> n = make_object<AddRFactorNode>();
7373
n->max_jobs_per_core = max_jobs_per_core;
74-
n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
74+
n->max_innermost_factor = max_innermost_factor.value_or(runtime::Int(-1))->value;
7575
n->max_parallel_extent_ = -1;
7676
n->max_parallel_basic_ = -1;
7777
return ScheduleRule(n);

src/meta_schedule/schedule_rule/auto_bind.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ class AutoBindNode : public ScheduleRuleNode {
3131
// Inherited from ScheduleRuleNode
3232
void InitializeWithTuneContext(const TuneContext& context) final {
3333
CHECK(context->target.defined()) << "ValueError: target is not defined";
34-
Optional<Integer> max_threads_per_block =
35-
context->target.value()->GetAttr<Integer>("max_threads_per_block");
34+
Optional<runtime::Int> max_threads_per_block =
35+
context->target.value()->GetAttr<runtime::Int>("max_threads_per_block");
3636
CHECK(max_threads_per_block.defined())
3737
<< "ValueError: missing attribute `max_threads_per_block` in the target";
38-
this->max_threads_per_block_ = max_threads_per_block.value().IntValue();
38+
this->max_threads_per_block_ = max_threads_per_block.value();
3939
}
4040

4141
// Inherited from ScheduleRuleNode
@@ -53,7 +53,7 @@ class AutoBindNode : public ScheduleRuleNode {
5353
/*! \brief The max number of threadblocks in the cuda device */
5454
int64_t max_threadblocks_ = -1;
5555
/*! \brief thread_extents Candidates of thread axis extent. */
56-
Array<Integer> thread_extents_;
56+
Array<runtime::Int> thread_extents_;
5757

5858
void VisitAttrs(tvm::AttrVisitor* v) {
5959
// `max_threads_per_block_` is not visited
@@ -72,7 +72,7 @@ Array<tir::Schedule> AutoBindNode::Apply(const tir::Schedule& sch, const tir::Bl
7272
return {sch};
7373
}
7474

75-
ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array<Integer> thread_extents,
75+
ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array<runtime::Int> thread_extents,
7676
int max_threads_per_block) {
7777
ObjectPtr<AutoBindNode> n = make_object<AutoBindNode>();
7878
n->max_threadblocks_ = max_threadblocks;

src/meta_schedule/schedule_rule/multi_level_tiling.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,8 @@ void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
392392
// Constructor
393393

394394
ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds,
395-
Optional<Integer> max_innermost_factor,
396-
Optional<Array<Integer>> vector_load_lens,
395+
Optional<runtime::Int> max_innermost_factor,
396+
Optional<Array<runtime::Int>> vector_load_lens,
397397
Optional<Map<String, ObjectRef>> reuse_read,
398398
Optional<Map<String, ObjectRef>> reuse_write,
399399
Optional<runtime::PackedFunc> filter_fn) {

src/meta_schedule/schedule_rule/multi_level_tiling.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,16 +239,16 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
239239

240240
template <typename NodeType>
241241
ObjectPtr<NodeType> MultiLevelTilingInitCommon(String structure, Optional<Array<String>> tile_binds,
242-
Optional<Integer> max_innermost_factor,
243-
Optional<Array<Integer>> vector_load_lens,
242+
Optional<runtime::Int> max_innermost_factor,
243+
Optional<Array<runtime::Int>> vector_load_lens,
244244
Optional<Map<String, ObjectRef>> reuse_read,
245245
Optional<Map<String, ObjectRef>> reuse_write) {
246246
ObjectPtr<NodeType> n = make_object<NodeType>();
247247
n->structure = structure;
248248
n->tile_binds = tile_binds.value_or({});
249-
n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
249+
n->max_innermost_factor = max_innermost_factor.value_or(runtime::Int(-1))->value;
250250
n->vector_load_lens = vector_load_lens.defined()
251-
? support::AsVector<Integer, int>(vector_load_lens.value())
251+
? support::AsVector<runtime::Int, int>(vector_load_lens.value())
252252
: std::vector<int>();
253253
n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
254254
n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();

src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@ inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorizat
890890

891891
ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
892892
Array<Map<String, String>> intrin_groups, String structure, Optional<Array<String>> tile_binds,
893-
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
893+
Optional<runtime::Int> max_innermost_factor, Optional<Array<runtime::Int>> vector_load_lens,
894894
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write,
895895
bool use_software_pipeline) {
896896
if (tile_binds.defined()) {

src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,9 @@ std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingWideVectorNode
113113
}
114114

115115
ScheduleRule ScheduleRule::MultiLevelTilingWideVector(
116-
String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
117-
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
116+
String structure, runtime::Int vector_length_in_bits,
117+
Optional<runtime::Int> max_innermost_factor, Optional<Map<String, ObjectRef>> reuse_read,
118+
Optional<Map<String, ObjectRef>> reuse_write) {
118119
auto node = MultiLevelTilingInitCommon<MultiLevelTilingWideVectorNode>(
119120
structure, NullOpt, max_innermost_factor, NullOpt, reuse_read, reuse_write);
120121
node->vector_length_in_bits = vector_length_in_bits->value;

0 commit comments

Comments
 (0)