Skip to content

Commit 14fb501

Browse files
authored
LoopUnswitching: Reimplement transformation (phate#1526)
This PR reimplements the loop unswitching transformation. It performs the following changes compared to the previous implementation: 1. Fixes a bug where the wrong origins for variables was picked up. In other words, the old transformation was broken. 2. Avoids the duplication of the exit subregion (the old transformation duplicates it) 3. Adds documentation to avoid confusion in the future. Close phate#895
1 parent 3c74883 commit 14fb501

File tree

5 files changed

+198
-219
lines changed

5 files changed

+198
-219
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
13805
1+
13758
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
13813
1+
13766

jlm/llvm/opt/LoopUnswitching.cpp

Lines changed: 122 additions & 179 deletions
Original file line numberDiff line numberDiff line change
@@ -152,166 +152,6 @@ LoopUnswitching::CopyPredicateNodes(
152152
}
153153
}
154154

155-
rvsdg::SubstitutionMap
156-
LoopUnswitching::handleGammaExitRegion(
157-
const ThetaGammaPredicateCorrelation & correlation,
158-
rvsdg::GammaNode & newGammaNode,
159-
const rvsdg::SubstitutionMap & substitutionMap)
160-
{
161-
const auto & oldThetaNode = correlation.thetaNode();
162-
const auto & oldGammaNode = correlation.gammaNode();
163-
const auto [_, oldExitSubregion] = determineGammaSubregionRoles(correlation).value();
164-
const auto newExitSubregion = newGammaNode.subregion(oldExitSubregion->index());
165-
const auto exitSubregionIndex = oldExitSubregion->index();
166-
167-
// Setup substitution map for exit region copying
168-
rvsdg::SubstitutionMap exitSubregionMap;
169-
for (const auto & [oldInput, oldBranchArgument] : oldGammaNode.GetEntryVars())
170-
{
171-
rvsdg::Output * newGammaInputOrigin = nullptr;
172-
auto & oldGammaInputOrigin = *oldInput->origin();
173-
174-
if (rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(oldGammaInputOrigin))
175-
{
176-
const auto oldLoopVar = oldThetaNode.MapPreLoopVar(oldGammaInputOrigin);
177-
newGammaInputOrigin = oldLoopVar.input->origin();
178-
}
179-
else if (std::holds_alternative<rvsdg::Node *>(oldGammaInputOrigin.GetOwner()))
180-
{
181-
// The origin is from one of the predicate nodes
182-
const auto substitute = &substitutionMap.lookup(oldGammaInputOrigin);
183-
newGammaInputOrigin = substitute;
184-
}
185-
else
186-
{
187-
throw std::logic_error("This should have never happened!");
188-
}
189-
190-
auto [_, newBranchArgument] = newGammaNode.AddEntryVar(newGammaInputOrigin);
191-
exitSubregionMap.insert(
192-
oldBranchArgument[exitSubregionIndex],
193-
newBranchArgument[exitSubregionIndex]);
194-
}
195-
196-
oldExitSubregion->copy(newExitSubregion, exitSubregionMap);
197-
198-
// Update substitution map for insertion of exit variables
199-
for (const auto & oldLoopVar : oldThetaNode.GetLoopVars())
200-
{
201-
const auto output = oldLoopVar.post->origin();
202-
auto [oldBranchResult, _] = oldGammaNode.MapOutputExitVar(*output);
203-
const auto substitute =
204-
&exitSubregionMap.lookup(*oldBranchResult[exitSubregionIndex]->origin());
205-
exitSubregionMap.insert(output, substitute);
206-
}
207-
208-
return exitSubregionMap;
209-
}
210-
211-
rvsdg::SubstitutionMap
212-
LoopUnswitching::handleGammaRepetitionRegion(
213-
const ThetaGammaPredicateCorrelation & correlation,
214-
rvsdg::GammaNode & newGammaNode,
215-
const std::vector<std::vector<rvsdg::Node *>> & predicateNodes,
216-
const rvsdg::SubstitutionMap & substitutionMap)
217-
{
218-
const auto & oldThetaNode = correlation.thetaNode();
219-
const auto & oldGammaNode = correlation.gammaNode();
220-
const auto [oldRepetitionSubregion, oldExitSubregion] =
221-
determineGammaSubregionRoles(correlation).value();
222-
const auto & newRepetitionSubregion = newGammaNode.subregion(oldRepetitionSubregion->index());
223-
const auto repetitionSubregionIndex = oldRepetitionSubregion->index();
224-
const auto exitSubregionIndex = oldExitSubregion->index();
225-
const auto newThetaNode = rvsdg::ThetaNode::create(newRepetitionSubregion);
226-
227-
// Add loop variables to new theta node and setup substitution map
228-
rvsdg::SubstitutionMap repetitionSubregionMap;
229-
std::unordered_map<rvsdg::Input *, rvsdg::ThetaNode::LoopVar> newLoopVars;
230-
for (const auto & oldLoopVar : oldThetaNode.GetLoopVars())
231-
{
232-
auto [_, branchArgument] = newGammaNode.AddEntryVar(oldLoopVar.input->origin());
233-
auto newLoopVar = newThetaNode->AddLoopVar(branchArgument[repetitionSubregionIndex]);
234-
repetitionSubregionMap.insert(oldLoopVar.pre, newLoopVar.pre);
235-
newLoopVars[oldLoopVar.input] = newLoopVar;
236-
}
237-
for (const auto & [oldInput, oldBranchArgument] : oldGammaNode.GetEntryVars())
238-
{
239-
if (rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(*oldInput->origin()))
240-
{
241-
auto oldLoopVar = oldThetaNode.MapPreLoopVar(*oldInput->origin());
242-
repetitionSubregionMap.insert(
243-
oldBranchArgument[repetitionSubregionIndex],
244-
newLoopVars[oldLoopVar.input].pre);
245-
}
246-
else
247-
{
248-
auto [_, newBranchArgument] =
249-
newGammaNode.AddEntryVar(&substitutionMap.lookup(*oldInput->origin()));
250-
auto newLoopVar = newThetaNode->AddLoopVar(newBranchArgument[repetitionSubregionIndex]);
251-
repetitionSubregionMap.insert(oldBranchArgument[repetitionSubregionIndex], newLoopVar.pre);
252-
newLoopVars[oldInput] = newLoopVar;
253-
}
254-
}
255-
256-
// Copy repetition region
257-
oldRepetitionSubregion->copy(newThetaNode->subregion(), repetitionSubregionMap);
258-
259-
// Adjust values in substitution map for condition node copying
260-
for (const auto & oldLopVar : oldThetaNode.GetLoopVars())
261-
{
262-
auto output = oldLopVar.post->origin();
263-
auto substitute =
264-
&repetitionSubregionMap.lookup(*oldRepetitionSubregion->result(output->index())->origin());
265-
repetitionSubregionMap.insert(oldLopVar.pre, substitute);
266-
}
267-
268-
// Copy condition nodes
269-
CopyPredicateNodes(*newThetaNode->subregion(), repetitionSubregionMap, predicateNodes);
270-
auto predicate = &repetitionSubregionMap.lookup(*oldGammaNode.predicate()->origin());
271-
272-
// Redirect results of loop variables and adjust substitution map for exit region copying
273-
for (const auto & oldLoopVar : oldThetaNode.GetLoopVars())
274-
{
275-
auto output = oldLoopVar.post->origin();
276-
auto substitute =
277-
&repetitionSubregionMap.lookup(*oldRepetitionSubregion->result(output->index())->origin());
278-
newLoopVars[oldLoopVar.input].post->divert_to(substitute);
279-
repetitionSubregionMap.insert(oldLoopVar.post->origin(), newLoopVars[oldLoopVar.input].output);
280-
}
281-
for (const auto & [input, branchArgument] : oldGammaNode.GetEntryVars())
282-
{
283-
if (rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(*input->origin()))
284-
{
285-
auto oldLoopVar = oldThetaNode.MapPreLoopVar(*input->origin());
286-
repetitionSubregionMap.insert(
287-
branchArgument[exitSubregionIndex],
288-
newLoopVars[oldLoopVar.input].output);
289-
}
290-
else
291-
{
292-
auto substitute = &repetitionSubregionMap.lookup(*input->origin());
293-
newLoopVars[input].post->divert_to(substitute);
294-
repetitionSubregionMap.insert(branchArgument[exitSubregionIndex], newLoopVars[input].output);
295-
}
296-
}
297-
298-
newThetaNode->set_predicate(predicate);
299-
300-
// Copy exit region
301-
oldExitSubregion->copy(newRepetitionSubregion, repetitionSubregionMap);
302-
303-
// Adjust values in substitution map for exit variable creation
304-
for (const auto & oldLoopVar : oldThetaNode.GetLoopVars())
305-
{
306-
auto output = oldLoopVar.post->origin();
307-
auto substitute =
308-
&repetitionSubregionMap.lookup(*oldExitSubregion->result(output->index())->origin());
309-
repetitionSubregionMap.insert(oldLoopVar.post->origin(), substitute);
310-
}
311-
312-
return repetitionSubregionMap;
313-
}
314-
315155
bool
316156
LoopUnswitching::allLoopVarsAreRoutedThroughGamma(
317157
const rvsdg::ThetaNode & thetaNode,
@@ -336,6 +176,10 @@ LoopUnswitching::UnswitchLoop(rvsdg::ThetaNode & oldThetaNode)
336176

337177
SinkNodesIntoGamma(*oldGammaNode, oldThetaNode);
338178

179+
// At this point, we have established the following invariant:
180+
// All loop variables of the original theta node are routed through its contained gamma node,
181+
// i.e., the origin of a loop variables' post value must be the output of the gamma node. This
182+
// helps to simplify the transformation significantly.
339183
JLM_ASSERT(allLoopVarsAreRoutedThroughGamma(oldThetaNode, *oldGammaNode));
340184

341185
// FIXME: We should get this correlation from the IsUnswitchable() method, if it is possible
@@ -344,35 +188,134 @@ LoopUnswitching::UnswitchLoop(rvsdg::ThetaNode & oldThetaNode)
344188
JLM_ASSERT(correlationOpt.has_value());
345189
auto & correlation = correlationOpt.value();
346190

347-
// Copy condition nodes for new gamma node
348-
rvsdg::SubstitutionMap substitutionMap;
349-
for (const auto & oldLoopVar : oldThetaNode.GetLoopVars())
350-
substitutionMap.insert(oldLoopVar.pre, oldLoopVar.input->origin());
191+
// The rest of the transformation is now performed in several stages:
192+
// 1. Copy the predicate subgraph into the parent region of the old theta node.
193+
// 2. Create gamma node (using the copied predicate) and the new theta node in the new gamma
194+
// nodes' repetition subregion.
195+
// 3. Copy old repetition subregion into the new theta node.
196+
// 4. Copy predicate subgraph into new theta node.
197+
// 5. Adjust the loop variables of the new theta node, finalizing the new theta node.
198+
// 6. Add exit variables to the new gamma node, finalizing the new gamma node.
199+
// 7. Copy old exit subregion into the parent region of the old theta node.
200+
// 8. Divert the users of old theta nodes' loop variables, rendering the old theta
201+
// node dead.
202+
//
203+
// Along the way, we keep track of replaced variables with substitution maps for some of the
204+
// stages. These substitution maps are then utilized by succeeding stages to find the correct
205+
// replacements for old outputs.
206+
207+
// Stage 1 - Copy predicate subgraph into the old theta nodes' parent region
208+
rvsdg::SubstitutionMap stage1SMap;
209+
{
210+
for (const auto & oldLoopVar : oldThetaNode.GetLoopVars())
211+
stage1SMap.insert(oldLoopVar.pre, oldLoopVar.input->origin());
351212

352-
auto conditionNodes = CollectPredicateNodes(oldThetaNode, *oldGammaNode);
353-
CopyPredicateNodes(*oldThetaNode.region(), substitutionMap, conditionNodes);
213+
auto conditionNodes = CollectPredicateNodes(oldThetaNode, *oldGammaNode);
214+
CopyPredicateNodes(*oldThetaNode.region(), stage1SMap, conditionNodes);
215+
}
354216

217+
// Stage 2 - Create new gamma and theta node
355218
auto newGammaNode = rvsdg::GammaNode::create(
356-
&substitutionMap.lookup(*oldGammaNode->predicate()->origin()),
219+
&stage1SMap.lookup(*oldGammaNode->predicate()->origin()),
357220
oldGammaNode->nsubregions());
358221

359-
auto exitSubregionMap = handleGammaExitRegion(*correlation, *newGammaNode, substitutionMap);
222+
const auto [oldRepetitionSubregion, oldExitSubregion] =
223+
determineGammaSubregionRoles(*correlation).value();
224+
const auto & newRepetitionSubregion = newGammaNode->subregion(oldRepetitionSubregion->index());
225+
const auto repetitionSubregionIndex = oldRepetitionSubregion->index();
226+
const auto exitSubregionIndex = oldExitSubregion->index();
360227

361-
auto repetitionSubstitutionMap =
362-
handleGammaRepetitionRegion(*correlation, *newGammaNode, conditionNodes, substitutionMap);
228+
auto newThetaNode = rvsdg::ThetaNode::create(newRepetitionSubregion);
363229

364-
// Add exit variables to new gamma
365-
for (const auto & oldLoopVar : oldThetaNode.GetLoopVars())
230+
std::unordered_map<rvsdg::Input *, rvsdg::Input *> oldGammaNewGammaInputMap;
231+
std::unordered_map<rvsdg::Input *, rvsdg::Input *> oldGammaNewThetaInputMap;
232+
for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars())
366233
{
367-
auto o0 = &exitSubregionMap.lookup(*oldLoopVar.post->origin());
368-
auto o1 = &repetitionSubstitutionMap.lookup(*oldLoopVar.post->origin());
369-
auto [_, output] = newGammaNode->AddExitVar({ o0, o1 });
370-
substitutionMap.insert(oldLoopVar.output, output);
234+
auto & newOrigin = stage1SMap.lookup(*oldInput->origin());
235+
auto newEntryVar = newGammaNode->AddEntryVar(&newOrigin);
236+
auto newLoopVar =
237+
newThetaNode->AddLoopVar(newEntryVar.branchArgument[repetitionSubregionIndex]);
238+
oldGammaNewGammaInputMap[oldInput] = newEntryVar.input;
239+
oldGammaNewThetaInputMap[oldInput] = newLoopVar.input;
240+
}
241+
242+
// Stage 3 - Copy repetition subregion into new theta node
243+
rvsdg::SubstitutionMap stage3SMap;
244+
{
245+
for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars())
246+
{
247+
auto newLoopInput = oldGammaNewThetaInputMap[oldInput];
248+
auto newLoopVar = newThetaNode->MapInputLoopVar(*newLoopInput);
249+
stage3SMap.insert(oldBranchArgument[repetitionSubregionIndex], newLoopVar.pre);
250+
}
251+
252+
oldRepetitionSubregion->copy(newThetaNode->subregion(), stage3SMap);
371253
}
372254

373-
// Replace outputs
374-
for (const auto & oldLoopVar : oldThetaNode.GetLoopVars())
375-
oldLoopVar.output->divert_users(&substitutionMap.lookup(*oldLoopVar.output));
255+
// Stage 4 - Copy predicate subgraph into new theta node subregion
256+
rvsdg::SubstitutionMap stage4SMap;
257+
{
258+
for (auto oldLoopVar : oldThetaNode.GetLoopVars())
259+
{
260+
auto oldExitVar = oldGammaNode->MapOutputExitVar(*oldLoopVar.post->origin());
261+
auto oldOrigin = oldExitVar.branchResult[repetitionSubregionIndex]->origin();
262+
auto & newOrigin = stage3SMap.lookup(*oldOrigin);
263+
stage4SMap.insert(oldLoopVar.pre, &newOrigin);
264+
}
265+
266+
auto conditionNodes = CollectPredicateNodes(oldThetaNode, *oldGammaNode);
267+
CopyPredicateNodes(*newThetaNode->subregion(), stage4SMap, conditionNodes);
268+
}
269+
270+
// Stage 5 - Adjust loop variables
271+
newThetaNode->set_predicate(&stage4SMap.lookup(*oldThetaNode.predicate()->origin()));
272+
for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars())
273+
{
274+
auto newLoopVarInput = oldGammaNewThetaInputMap[oldInput];
275+
auto newLoopVar = newThetaNode->MapInputLoopVar(*newLoopVarInput);
276+
auto & newOrigin = stage4SMap.lookup(*oldInput->origin());
277+
newLoopVar.post->divert_to(&newOrigin);
278+
}
279+
280+
// Stage 6 - Add new gamma exit variables
281+
std::unordered_map<rvsdg::Input *, rvsdg::Output *> oldGammaNewGammaOutputMap;
282+
{
283+
for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars())
284+
{
285+
auto newGammaInput = oldGammaNewGammaInputMap[oldInput];
286+
auto newEntryVar =
287+
std::get<rvsdg::GammaNode::EntryVar>(newGammaNode->MapInput(*newGammaInput));
288+
auto newLoopVarInput = oldGammaNewThetaInputMap[oldInput];
289+
auto newLoopVar = newThetaNode->MapInputLoopVar(*newLoopVarInput);
290+
291+
std::vector<rvsdg::Output *> values(2);
292+
values[exitSubregionIndex] = newEntryVar.branchArgument[exitSubregionIndex];
293+
values[repetitionSubregionIndex] = newLoopVar.output;
294+
auto newExitVar = newGammaNode->AddExitVar(values);
295+
oldGammaNewGammaOutputMap[oldInput] = newExitVar.output;
296+
}
297+
}
298+
299+
// Stage 7 - Copy exit subregion into old theta node parent region
300+
rvsdg::SubstitutionMap stage7SMap;
301+
{
302+
for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars())
303+
{
304+
auto newOrigin = oldGammaNewGammaOutputMap[oldInput];
305+
stage7SMap.insert(oldBranchArgument[exitSubregionIndex], newOrigin);
306+
}
307+
308+
oldExitSubregion->copy(oldThetaNode.region(), stage7SMap);
309+
}
310+
311+
// Stage 8 - Replace old theta node outputs
312+
for (auto oldLoopVar : oldThetaNode.GetLoopVars())
313+
{
314+
auto oldExitVar = oldGammaNode->MapOutputExitVar(*oldLoopVar.post->origin());
315+
auto oldOrigin = oldExitVar.branchResult[exitSubregionIndex]->origin();
316+
auto & newOrigin = stage7SMap.lookup(*oldOrigin);
317+
oldLoopVar.output->divert_users(&newOrigin);
318+
}
376319

377320
return true;
378321
}

0 commit comments

Comments
 (0)