@@ -31,13 +31,15 @@ enum MapperCallIDs {
31
31
32
32
static Logger log_map (" fuzz_mapper" );
33
33
34
- FuzzMapper::FuzzMapper (MapperRuntime *rt, Machine machine, Processor local, RngStream st)
34
+ FuzzMapper::FuzzMapper (MapperRuntime *rt, Machine machine, Processor local, RngStream st,
35
+ uint64_t replicate)
35
36
: NullMapper(rt, machine),
36
37
stream (st),
37
38
select_tasks_to_map_channel(st.make_channel(int32_t (SELECT_TASKS_TO_MAP))),
38
39
map_inline_channel(st.make_channel(int32_t (MAP_INLINE))),
39
40
select_inline_sources_channel(st.make_channel(int32_t (SELECT_INLINE_SOURCES))),
40
- local_proc(local) {
41
+ local_proc(local),
42
+ replicate_levels(replicate) {
41
43
// TODO: something other than CPU processor
42
44
{
43
45
Machine::ProcessorQuery query (machine);
@@ -85,7 +87,7 @@ void FuzzMapper::select_task_options(const MapperContext ctx, const Task &task,
85
87
output.map_locally = false ; // TODO
86
88
output.valid_instances = false ;
87
89
output.memoize = true ;
88
- output.replicate = task.get_depth () == 0 ; // TODO: replicate other tasks
90
+ output.replicate = task.get_depth () < static_cast < int64_t >(replicate_levels);
89
91
// output.parent_priority = ...; // Leave parent at current priority.
90
92
// output.check_collective_regions.insert(...); // TODO
91
93
}
@@ -124,7 +126,11 @@ void FuzzMapper::map_task(const MapperContext ctx, const Task &task,
124
126
log_map.debug () << " map_task: Selected variant " << output.chosen_variant ;
125
127
126
128
// TODO: assign to variant's correct processor kind
127
- if (rng.uniform_range (0 , 1 ) == 0 ) {
129
+ output.target_procs .clear ();
130
+ if (input.shard_processor .exists ()) {
131
+ log_map.debug () << " map_task: Mapping to shard proc" ;
132
+ output.target_procs .push_back (input.shard_processor );
133
+ } else if (rng.uniform_range (0 , 1 ) == 0 ) {
128
134
log_map.debug () << " map_task: Mapping to all local procs" ;
129
135
output.target_procs .insert (output.target_procs .end (), local_procs.begin (),
130
136
local_procs.end ());
@@ -149,6 +155,8 @@ void FuzzMapper::map_task(const MapperContext ctx, const Task &task,
149
155
void FuzzMapper::replicate_task (MapperContext ctx, const Task &task,
150
156
const ReplicateTaskInput &input,
151
157
ReplicateTaskOutput &output) {
158
+ if (task.get_depth () >= static_cast <int64_t >(replicate_levels)) return ;
159
+
152
160
// TODO: cache this?
153
161
std::vector<VariantID> variants;
154
162
runtime->find_valid_variants (ctx, task.task_id , variants);
@@ -158,7 +166,28 @@ void FuzzMapper::replicate_task(MapperContext ctx, const Task &task,
158
166
abort ();
159
167
}
160
168
output.chosen_variant = variants.at (0 );
161
- // TODO: actually replicate
169
+
170
+ bool is_replicable =
171
+ runtime->is_replicable_variant (ctx, task.task_id , output.chosen_variant );
172
+ // For now assume we always have replicable variants at this level.
173
+ if (!is_replicable) {
174
+ log_map.fatal () << " Bad variants in replicate_task: variant is not replicable" ;
175
+ abort ();
176
+ }
177
+
178
+ std::map<AddressSpace, Processor> targets;
179
+ for (Processor proc : global_procs) {
180
+ AddressSpace space = proc.address_space ();
181
+ if (!targets.count (space)) {
182
+ targets[space] = proc;
183
+ }
184
+ }
185
+
186
+ if (targets.size () > 1 ) {
187
+ for (auto &target : targets) {
188
+ output.target_processors .push_back (target.second );
189
+ }
190
+ }
162
191
}
163
192
164
193
void FuzzMapper::select_task_sources (const MapperContext ctx, const Task &task,
@@ -168,6 +197,13 @@ void FuzzMapper::select_task_sources(const MapperContext ctx, const Task &task,
168
197
random_sources (rng, input.source_instances , output.chosen_ranking );
169
198
}
170
199
200
+ void FuzzMapper::select_sharding_functor (const MapperContext ctx, const Task &task,
201
+ const SelectShardingFunctorInput &input,
202
+ SelectShardingFunctorOutput &output) {
203
+ // TODO: customize the sharding functor
204
+ output.chosen_functor = 0 ;
205
+ }
206
+
171
207
void FuzzMapper::map_inline (const MapperContext ctx, const InlineMapping &inline_op,
172
208
const MapInlineInput &input, MapInlineOutput &output) {
173
209
RngChannel &rng = map_inline_channel;
0 commit comments