@@ -152,166 +152,6 @@ LoopUnswitching::CopyPredicateNodes(
152152 }
153153}
154154
155- rvsdg::SubstitutionMap
156- LoopUnswitching::handleGammaExitRegion (
157- const ThetaGammaPredicateCorrelation & correlation,
158- rvsdg::GammaNode & newGammaNode,
159- const rvsdg::SubstitutionMap & substitutionMap)
160- {
161- const auto & oldThetaNode = correlation.thetaNode ();
162- const auto & oldGammaNode = correlation.gammaNode ();
163- const auto [_, oldExitSubregion] = determineGammaSubregionRoles (correlation).value ();
164- const auto newExitSubregion = newGammaNode.subregion (oldExitSubregion->index ());
165- const auto exitSubregionIndex = oldExitSubregion->index ();
166-
167- // Setup substitution map for exit region copying
168- rvsdg::SubstitutionMap exitSubregionMap;
169- for (const auto & [oldInput, oldBranchArgument] : oldGammaNode.GetEntryVars ())
170- {
171- rvsdg::Output * newGammaInputOrigin = nullptr ;
172- auto & oldGammaInputOrigin = *oldInput->origin ();
173-
174- if (rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(oldGammaInputOrigin))
175- {
176- const auto oldLoopVar = oldThetaNode.MapPreLoopVar (oldGammaInputOrigin);
177- newGammaInputOrigin = oldLoopVar.input ->origin ();
178- }
179- else if (std::holds_alternative<rvsdg::Node *>(oldGammaInputOrigin.GetOwner ()))
180- {
181- // The origin is from one of the predicate nodes
182- const auto substitute = &substitutionMap.lookup (oldGammaInputOrigin);
183- newGammaInputOrigin = substitute;
184- }
185- else
186- {
187- throw std::logic_error (" This should have never happened!" );
188- }
189-
190- auto [_, newBranchArgument] = newGammaNode.AddEntryVar (newGammaInputOrigin);
191- exitSubregionMap.insert (
192- oldBranchArgument[exitSubregionIndex],
193- newBranchArgument[exitSubregionIndex]);
194- }
195-
196- oldExitSubregion->copy (newExitSubregion, exitSubregionMap);
197-
198- // Update substitution map for insertion of exit variables
199- for (const auto & oldLoopVar : oldThetaNode.GetLoopVars ())
200- {
201- const auto output = oldLoopVar.post ->origin ();
202- auto [oldBranchResult, _] = oldGammaNode.MapOutputExitVar (*output);
203- const auto substitute =
204- &exitSubregionMap.lookup (*oldBranchResult[exitSubregionIndex]->origin ());
205- exitSubregionMap.insert (output, substitute);
206- }
207-
208- return exitSubregionMap;
209- }
210-
211- rvsdg::SubstitutionMap
212- LoopUnswitching::handleGammaRepetitionRegion (
213- const ThetaGammaPredicateCorrelation & correlation,
214- rvsdg::GammaNode & newGammaNode,
215- const std::vector<std::vector<rvsdg::Node *>> & predicateNodes,
216- const rvsdg::SubstitutionMap & substitutionMap)
217- {
218- const auto & oldThetaNode = correlation.thetaNode ();
219- const auto & oldGammaNode = correlation.gammaNode ();
220- const auto [oldRepetitionSubregion, oldExitSubregion] =
221- determineGammaSubregionRoles (correlation).value ();
222- const auto & newRepetitionSubregion = newGammaNode.subregion (oldRepetitionSubregion->index ());
223- const auto repetitionSubregionIndex = oldRepetitionSubregion->index ();
224- const auto exitSubregionIndex = oldExitSubregion->index ();
225- const auto newThetaNode = rvsdg::ThetaNode::create (newRepetitionSubregion);
226-
227- // Add loop variables to new theta node and setup substitution map
228- rvsdg::SubstitutionMap repetitionSubregionMap;
229- std::unordered_map<rvsdg::Input *, rvsdg::ThetaNode::LoopVar> newLoopVars;
230- for (const auto & oldLoopVar : oldThetaNode.GetLoopVars ())
231- {
232- auto [_, branchArgument] = newGammaNode.AddEntryVar (oldLoopVar.input ->origin ());
233- auto newLoopVar = newThetaNode->AddLoopVar (branchArgument[repetitionSubregionIndex]);
234- repetitionSubregionMap.insert (oldLoopVar.pre , newLoopVar.pre );
235- newLoopVars[oldLoopVar.input ] = newLoopVar;
236- }
237- for (const auto & [oldInput, oldBranchArgument] : oldGammaNode.GetEntryVars ())
238- {
239- if (rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(*oldInput->origin ()))
240- {
241- auto oldLoopVar = oldThetaNode.MapPreLoopVar (*oldInput->origin ());
242- repetitionSubregionMap.insert (
243- oldBranchArgument[repetitionSubregionIndex],
244- newLoopVars[oldLoopVar.input ].pre );
245- }
246- else
247- {
248- auto [_, newBranchArgument] =
249- newGammaNode.AddEntryVar (&substitutionMap.lookup (*oldInput->origin ()));
250- auto newLoopVar = newThetaNode->AddLoopVar (newBranchArgument[repetitionSubregionIndex]);
251- repetitionSubregionMap.insert (oldBranchArgument[repetitionSubregionIndex], newLoopVar.pre );
252- newLoopVars[oldInput] = newLoopVar;
253- }
254- }
255-
256- // Copy repetition region
257- oldRepetitionSubregion->copy (newThetaNode->subregion (), repetitionSubregionMap);
258-
259- // Adjust values in substitution map for condition node copying
260- for (const auto & oldLopVar : oldThetaNode.GetLoopVars ())
261- {
262- auto output = oldLopVar.post ->origin ();
263- auto substitute =
264- &repetitionSubregionMap.lookup (*oldRepetitionSubregion->result (output->index ())->origin ());
265- repetitionSubregionMap.insert (oldLopVar.pre , substitute);
266- }
267-
268- // Copy condition nodes
269- CopyPredicateNodes (*newThetaNode->subregion (), repetitionSubregionMap, predicateNodes);
270- auto predicate = &repetitionSubregionMap.lookup (*oldGammaNode.predicate ()->origin ());
271-
272- // Redirect results of loop variables and adjust substitution map for exit region copying
273- for (const auto & oldLoopVar : oldThetaNode.GetLoopVars ())
274- {
275- auto output = oldLoopVar.post ->origin ();
276- auto substitute =
277- &repetitionSubregionMap.lookup (*oldRepetitionSubregion->result (output->index ())->origin ());
278- newLoopVars[oldLoopVar.input ].post ->divert_to (substitute);
279- repetitionSubregionMap.insert (oldLoopVar.post ->origin (), newLoopVars[oldLoopVar.input ].output );
280- }
281- for (const auto & [input, branchArgument] : oldGammaNode.GetEntryVars ())
282- {
283- if (rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(*input->origin ()))
284- {
285- auto oldLoopVar = oldThetaNode.MapPreLoopVar (*input->origin ());
286- repetitionSubregionMap.insert (
287- branchArgument[exitSubregionIndex],
288- newLoopVars[oldLoopVar.input ].output );
289- }
290- else
291- {
292- auto substitute = &repetitionSubregionMap.lookup (*input->origin ());
293- newLoopVars[input].post ->divert_to (substitute);
294- repetitionSubregionMap.insert (branchArgument[exitSubregionIndex], newLoopVars[input].output );
295- }
296- }
297-
298- newThetaNode->set_predicate (predicate);
299-
300- // Copy exit region
301- oldExitSubregion->copy (newRepetitionSubregion, repetitionSubregionMap);
302-
303- // Adjust values in substitution map for exit variable creation
304- for (const auto & oldLoopVar : oldThetaNode.GetLoopVars ())
305- {
306- auto output = oldLoopVar.post ->origin ();
307- auto substitute =
308- &repetitionSubregionMap.lookup (*oldExitSubregion->result (output->index ())->origin ());
309- repetitionSubregionMap.insert (oldLoopVar.post ->origin (), substitute);
310- }
311-
312- return repetitionSubregionMap;
313- }
314-
315155bool
316156LoopUnswitching::allLoopVarsAreRoutedThroughGamma (
317157 const rvsdg::ThetaNode & thetaNode,
@@ -336,6 +176,10 @@ LoopUnswitching::UnswitchLoop(rvsdg::ThetaNode & oldThetaNode)
336176
337177 SinkNodesIntoGamma (*oldGammaNode, oldThetaNode);
338178
179+ // At this point, we have established the following invariant:
180+ // All loop variables of the original theta node are routed through its contained gamma node,
181+ // i.e., the origin of a loop variable's post value must be an output of the gamma node. This
182+ // helps to simplify the transformation significantly.
339183 JLM_ASSERT (allLoopVarsAreRoutedThroughGamma (oldThetaNode, *oldGammaNode));
340184
341185 // FIXME: We should get this correlation from the IsUnswitchable() method, if it is possible
@@ -344,35 +188,134 @@ LoopUnswitching::UnswitchLoop(rvsdg::ThetaNode & oldThetaNode)
344188 JLM_ASSERT (correlationOpt.has_value ());
345189 auto & correlation = correlationOpt.value ();
346190
347- // Copy condition nodes for new gamma node
348- rvsdg::SubstitutionMap substitutionMap;
349- for (const auto & oldLoopVar : oldThetaNode.GetLoopVars ())
350- substitutionMap.insert (oldLoopVar.pre , oldLoopVar.input ->origin ());
191+ // The rest of the transformation is now performed in several stages:
192+ // 1. Copy the predicate subgraph into the parent region of the old theta node.
193+ // 2. Create gamma node (using the copied predicate) and the new theta node in the new gamma
194+ // node's repetition subregion.
195+ // 3. Copy old repetition subregion into the new theta node.
196+ // 4. Copy predicate subgraph into new theta node.
197+ // 5. Adjust the loop variables of the new theta node, finalizing the new theta node.
198+ // 6. Add exit variables to the new gamma node, finalizing the new gamma node.
199+ // 7. Copy old exit subregion into the parent region of the old theta node.
200+ // 8. Divert the users of the old theta node's loop variables, rendering the old theta
201+ // node dead.
202+ //
203+ // Along the way, we keep track of replaced variables with substitution maps for some of the
204+ // stages. These substitution maps are then utilized by succeeding stages to find the correct
205+ // replacements for old outputs.
206+
207+ // Stage 1 - Copy predicate subgraph into the old theta node's parent region
208+ rvsdg::SubstitutionMap stage1SMap;
209+ {
210+ for (const auto & oldLoopVar : oldThetaNode.GetLoopVars ())
211+ stage1SMap.insert (oldLoopVar.pre , oldLoopVar.input ->origin ());
351212
352- auto conditionNodes = CollectPredicateNodes (oldThetaNode, *oldGammaNode);
353- CopyPredicateNodes (*oldThetaNode.region (), substitutionMap, conditionNodes);
213+ auto conditionNodes = CollectPredicateNodes (oldThetaNode, *oldGammaNode);
214+ CopyPredicateNodes (*oldThetaNode.region (), stage1SMap, conditionNodes);
215+ }
354216
217+ // Stage 2 - Create new gamma and theta node
355218 auto newGammaNode = rvsdg::GammaNode::create (
356- &substitutionMap .lookup (*oldGammaNode->predicate ()->origin ()),
219+ &stage1SMap .lookup (*oldGammaNode->predicate ()->origin ()),
357220 oldGammaNode->nsubregions ());
358221
359- auto exitSubregionMap = handleGammaExitRegion (*correlation, *newGammaNode, substitutionMap);
222+ const auto [oldRepetitionSubregion, oldExitSubregion] =
223+ determineGammaSubregionRoles (*correlation).value ();
224+ const auto & newRepetitionSubregion = newGammaNode->subregion (oldRepetitionSubregion->index ());
225+ const auto repetitionSubregionIndex = oldRepetitionSubregion->index ();
226+ const auto exitSubregionIndex = oldExitSubregion->index ();
360227
361- auto repetitionSubstitutionMap =
362- handleGammaRepetitionRegion (*correlation, *newGammaNode, conditionNodes, substitutionMap);
228+ auto newThetaNode = rvsdg::ThetaNode::create (newRepetitionSubregion);
363229
364- // Add exit variables to new gamma
365- for (const auto & oldLoopVar : oldThetaNode.GetLoopVars ())
230+ std::unordered_map<rvsdg::Input *, rvsdg::Input *> oldGammaNewGammaInputMap;
231+ std::unordered_map<rvsdg::Input *, rvsdg::Input *> oldGammaNewThetaInputMap;
232+ for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars ())
366233 {
367- auto o0 = &exitSubregionMap.lookup (*oldLoopVar.post ->origin ());
368- auto o1 = &repetitionSubstitutionMap.lookup (*oldLoopVar.post ->origin ());
369- auto [_, output] = newGammaNode->AddExitVar ({ o0, o1 });
370- substitutionMap.insert (oldLoopVar.output , output);
234+ auto & newOrigin = stage1SMap.lookup (*oldInput->origin ());
235+ auto newEntryVar = newGammaNode->AddEntryVar (&newOrigin);
236+ auto newLoopVar =
237+ newThetaNode->AddLoopVar (newEntryVar.branchArgument [repetitionSubregionIndex]);
238+ oldGammaNewGammaInputMap[oldInput] = newEntryVar.input ;
239+ oldGammaNewThetaInputMap[oldInput] = newLoopVar.input ;
240+ }
241+
242+ // Stage 3 - Copy repetition subregion into new theta node
243+ rvsdg::SubstitutionMap stage3SMap;
244+ {
245+ for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars ())
246+ {
247+ auto newLoopInput = oldGammaNewThetaInputMap[oldInput];
248+ auto newLoopVar = newThetaNode->MapInputLoopVar (*newLoopInput);
249+ stage3SMap.insert (oldBranchArgument[repetitionSubregionIndex], newLoopVar.pre );
250+ }
251+
252+ oldRepetitionSubregion->copy (newThetaNode->subregion (), stage3SMap);
371253 }
372254
373- // Replace outputs
374- for (const auto & oldLoopVar : oldThetaNode.GetLoopVars ())
375- oldLoopVar.output ->divert_users (&substitutionMap.lookup (*oldLoopVar.output ));
255+ // Stage 4 - Copy predicate subgraph into new theta node subregion
256+ rvsdg::SubstitutionMap stage4SMap;
257+ {
258+ for (auto oldLoopVar : oldThetaNode.GetLoopVars ())
259+ {
260+ auto oldExitVar = oldGammaNode->MapOutputExitVar (*oldLoopVar.post ->origin ());
261+ auto oldOrigin = oldExitVar.branchResult [repetitionSubregionIndex]->origin ();
262+ auto & newOrigin = stage3SMap.lookup (*oldOrigin);
263+ stage4SMap.insert (oldLoopVar.pre , &newOrigin);
264+ }
265+
266+ auto conditionNodes = CollectPredicateNodes (oldThetaNode, *oldGammaNode);
267+ CopyPredicateNodes (*newThetaNode->subregion (), stage4SMap, conditionNodes);
268+ }
269+
270+ // Stage 5 - Adjust loop variables
271+ newThetaNode->set_predicate (&stage4SMap.lookup (*oldThetaNode.predicate ()->origin ()));
272+ for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars ())
273+ {
274+ auto newLoopVarInput = oldGammaNewThetaInputMap[oldInput];
275+ auto newLoopVar = newThetaNode->MapInputLoopVar (*newLoopVarInput);
276+ auto & newOrigin = stage4SMap.lookup (*oldInput->origin ());
277+ newLoopVar.post ->divert_to (&newOrigin);
278+ }
279+
280+ // Stage 6 - Add new gamma exit variables
281+ std::unordered_map<rvsdg::Input *, rvsdg::Output *> oldGammaNewGammaOutputMap;
282+ {
283+ for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars ())
284+ {
285+ auto newGammaInput = oldGammaNewGammaInputMap[oldInput];
286+ auto newEntryVar =
287+ std::get<rvsdg::GammaNode::EntryVar>(newGammaNode->MapInput (*newGammaInput));
288+ auto newLoopVarInput = oldGammaNewThetaInputMap[oldInput];
289+ auto newLoopVar = newThetaNode->MapInputLoopVar (*newLoopVarInput);
290+
291+ std::vector<rvsdg::Output *> values (2 );
292+ values[exitSubregionIndex] = newEntryVar.branchArgument [exitSubregionIndex];
293+ values[repetitionSubregionIndex] = newLoopVar.output ;
294+ auto newExitVar = newGammaNode->AddExitVar (values);
295+ oldGammaNewGammaOutputMap[oldInput] = newExitVar.output ;
296+ }
297+ }
298+
299+ // Stage 7 - Copy exit subregion into old theta node parent region
300+ rvsdg::SubstitutionMap stage7SMap;
301+ {
302+ for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars ())
303+ {
304+ auto newOrigin = oldGammaNewGammaOutputMap[oldInput];
305+ stage7SMap.insert (oldBranchArgument[exitSubregionIndex], newOrigin);
306+ }
307+
308+ oldExitSubregion->copy (oldThetaNode.region (), stage7SMap);
309+ }
310+
311+ // Stage 8 - Replace old theta node outputs
312+ for (auto oldLoopVar : oldThetaNode.GetLoopVars ())
313+ {
314+ auto oldExitVar = oldGammaNode->MapOutputExitVar (*oldLoopVar.post ->origin ());
315+ auto oldOrigin = oldExitVar.branchResult [exitSubregionIndex]->origin ();
316+ auto & newOrigin = stage7SMap.lookup (*oldOrigin);
317+ oldLoopVar.output ->divert_users (&newOrigin);
318+ }
376319
377320 return true ;
378321}
0 commit comments