Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 57 additions & 54 deletions jlm/llvm/opt/alias-analyses/AgnosticModRefSummarizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class AgnosticModRefSummary final : public ModRefSummary
{
return it->second;
}
if (is<CallOperation>(&node))
if (is<CallOperation>(node.GetOperation()))
{
return AllMemoryNodes_;
}
Expand Down Expand Up @@ -211,59 +211,62 @@ AgnosticModRefSummarizer::AddPointerTargetsToModRefSet(
void
AgnosticModRefSummarizer::AnnotateSimpleNode(const rvsdg::SimpleNode & node)
{
if (is<StoreOperation>(&node))
{
const auto & address = *StoreOperation::AddressInput(node).origin();
util::HashSet<PointsToGraph::NodeIndex> modRefSet;
AddPointerTargetsToModRefSet(address, modRefSet);
ModRefSummary_->SetSimpleNodeModRef(node, std::move(modRefSet));
}
else if (is<LoadOperation>(&node))
{
const auto & address = *LoadOperation::AddressInput(node).origin();
util::HashSet<PointsToGraph::NodeIndex> modRefSet;
AddPointerTargetsToModRefSet(address, modRefSet);
ModRefSummary_->SetSimpleNodeModRef(node, std::move(modRefSet));
}
else if (is<MemCpyOperation>(&node))
{
util::HashSet<PointsToGraph::NodeIndex> modRefSet;
const auto & srcAddress = *MemCpyOperation::sourceInput(node).origin();
const auto & dstAddress = *MemCpyOperation::destinationInput(node).origin();
AddPointerTargetsToModRefSet(srcAddress, modRefSet);
AddPointerTargetsToModRefSet(dstAddress, modRefSet);
ModRefSummary_->SetSimpleNodeModRef(node, std::move(modRefSet));
}
else if (is<FreeOperation>(&node))
{
util::HashSet<PointsToGraph::NodeIndex> modRefSet;
const auto & freeAddress = *FreeOperation::addressInput(node).origin();
AddPointerTargetsToModRefSet(freeAddress, modRefSet);
ModRefSummary_->SetSimpleNodeModRef(node, std::move(modRefSet));
}
else if (is<AllocaOperation>(&node))
{
const auto allocaMemoryNode = ModRefSummary_->GetPointsToGraph().getNodeForAlloca(node);
ModRefSummary_->SetSimpleNodeModRef(node, { allocaMemoryNode });
}
else if (is<MallocOperation>(&node))
{
const auto mallocMemoryNode = ModRefSummary_->GetPointsToGraph().getNodeForMalloc(node);
ModRefSummary_->SetSimpleNodeModRef(node, { mallocMemoryNode });
}
else if (is<CallOperation>(&node))
{
// CallOperations are omitted on purpose, as calls use the AllMemoryNodes as their ModRef set.
}
else if (is<MemoryStateOperation>(&node))
{
// Memory state operations are only used to route memory state edges
}
else
{
// Any remaining type of node should not involve any memory states
JLM_ASSERT(!hasMemoryState(node));
}
MatchTypeWithDefault(
node.GetOperation(),
[&](const StoreOperation &)
{
const auto & address = *StoreOperation::AddressInput(node).origin();
util::HashSet<PointsToGraph::NodeIndex> modRefSet;
AddPointerTargetsToModRefSet(address, modRefSet);
ModRefSummary_->SetSimpleNodeModRef(node, std::move(modRefSet));
},
[&](const LoadOperation &)
{
const auto & address = *LoadOperation::AddressInput(node).origin();
util::HashSet<PointsToGraph::NodeIndex> modRefSet;
AddPointerTargetsToModRefSet(address, modRefSet);
ModRefSummary_->SetSimpleNodeModRef(node, std::move(modRefSet));
},
[&](const MemCpyOperation &)
{
util::HashSet<PointsToGraph::NodeIndex> modRefSet;
const auto & srcAddress = *MemCpyOperation::sourceInput(node).origin();
const auto & dstAddress = *MemCpyOperation::destinationInput(node).origin();
AddPointerTargetsToModRefSet(srcAddress, modRefSet);
AddPointerTargetsToModRefSet(dstAddress, modRefSet);
ModRefSummary_->SetSimpleNodeModRef(node, std::move(modRefSet));
},
[&](const FreeOperation &)
{
util::HashSet<PointsToGraph::NodeIndex> modRefSet;
const auto & freeAddress = *FreeOperation::addressInput(node).origin();
AddPointerTargetsToModRefSet(freeAddress, modRefSet);
ModRefSummary_->SetSimpleNodeModRef(node, std::move(modRefSet));
},
[&](const AllocaOperation &)
{
const auto allocaMemoryNode = ModRefSummary_->GetPointsToGraph().getNodeForAlloca(node);
ModRefSummary_->SetSimpleNodeModRef(node, { allocaMemoryNode });
},
[&](const MallocOperation &)
{
const auto mallocMemoryNode = ModRefSummary_->GetPointsToGraph().getNodeForMalloc(node);
ModRefSummary_->SetSimpleNodeModRef(node, { mallocMemoryNode });
},
[&](const CallOperation &)
{
// CallOperations are omitted on purpose, as calls use the AllMemoryNodes as their ModRef
// set.
},
[&](const MemoryStateOperation &)
{
// Memory state operations are only used to route memory state edges
},
[&]()
{
// Any remaining type of node should not involve any memory states
JLM_ASSERT(!hasMemoryState(node));
});
}

std::unique_ptr<ModRefSummary>
Expand Down
70 changes: 43 additions & 27 deletions jlm/llvm/opt/alias-analyses/LocalAliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <jlm/llvm/ir/types.hpp>
#include <jlm/rvsdg/gamma.hpp>
#include <jlm/rvsdg/lambda.hpp>
#include <jlm/rvsdg/MatchType.hpp>
#include <jlm/rvsdg/theta.hpp>

#include <numeric>
Expand Down Expand Up @@ -441,10 +442,10 @@ LocalAliasAnalysis::IsOriginalOrigin(const rvsdg::Output & pointer)
// Is pointer the output of one of the nodes
if (const auto node = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(pointer))
{
if (is<AllocaOperation>(node))
if (is<AllocaOperation>(node->GetOperation()))
return true;

if (is<MallocOperation>(node))
if (is<MallocOperation>(node->GetOperation()))
return true;
}

Expand Down Expand Up @@ -693,32 +694,47 @@ LocalAliasAnalysis::IsOriginalOriginFullyTraceable(const rvsdg::Output & pointer

if (auto node = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(user))
{
// Pointers go straight through IO barriers and GEPs
if (is<IOBarrierOperation>(node) || is<GetElementPtrOperation>(node))
{
// The pointer input must be the node's first input
JLM_ASSERT(user.index() == 0);
Enqueue(*node->output(0));
bool do_continue = MatchTypeWithDefault(
node->GetOperation(),
[&](const IOBarrierOperation &)
{
// The pointer input must be the node's first input
JLM_ASSERT(user.index() == 0);
Enqueue(*node->output(0));
return true;
},
[&](const GetElementPtrOperation &)
{
// The pointer input must be the node's first input
JLM_ASSERT(user.index() == 0);
Enqueue(*node->output(0));
return true;
},
[&](const SelectOperation &)
{
// Select operations are fine, if the output is still fully traceable
Enqueue(*node->output(0));
return true;
},
[&](const LoadOperation &)
{
// Loads are always fine
return true;
},
[&](const StoreOperation &)
{
// Stores are only fine if the pointer itself is not being stored somewhere
if (&user == &StoreOperation::AddressInput(*node))
return true;
else
return false;
},
[]()
{
return false;
});
if (do_continue)
continue;
}

if (is<SelectOperation>(node))
{
// Select operations are fine, if the output is still fully traceable
Enqueue(*node->output(0));
continue;
}

// Loads are always fine
if (is<LoadOperation>(node))
continue;

// Stores are only fine if the pointer itself is not being stored somewhere
if (is<StoreOperation>(node))
{
if (&user == &StoreOperation::AddressInput(*node))
continue;
}
}

// We were unable to handle this user, so the original pointer escapes tracing
Expand Down
94 changes: 48 additions & 46 deletions jlm/llvm/opt/alias-analyses/MemoryStateEncoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -702,49 +702,51 @@ MemoryStateEncoder::EncodeStructuralNode(rvsdg::StructuralNode & structuralNode)
void
MemoryStateEncoder::EncodeSimpleNode(const rvsdg::SimpleNode & simpleNode)
{
if (is<AllocaOperation>(&simpleNode))
{
EncodeAlloca(simpleNode);
}
else if (is<MallocOperation>(&simpleNode))
{
EncodeMalloc(simpleNode);
}
else if (is<LoadOperation>(&simpleNode))
{
EncodeLoad(simpleNode);
}
else if (is<StoreOperation>(&simpleNode))
{
EncodeStore(simpleNode);
}
else if (is<CallOperation>(&simpleNode))
{
EncodeCall(simpleNode);
}
else if (is<FreeOperation>(&simpleNode))
{
EncodeFree(simpleNode);
}
else if (is<MemCpyOperation>(&simpleNode))
{
EncodeMemcpy(simpleNode);
}
else if (is<MemoryStateOperation>(&simpleNode))
{
// Nothing needs to be done
}
else
{
// Ensure we took care of all memory state consuming/producing nodes
JLM_ASSERT(!hasMemoryState(simpleNode));
}
MatchTypeWithDefault(
simpleNode.GetOperation(),
[&](const AllocaOperation &)
{
EncodeAlloca(simpleNode);
},
[&](const MallocOperation &)
{
EncodeMalloc(simpleNode);
},
[&](const LoadOperation &)
{
EncodeLoad(simpleNode);
},
[&](const StoreOperation &)
{
EncodeStore(simpleNode);
},
[&](const CallOperation &)
{
EncodeCall(simpleNode);
},
[&](const FreeOperation &)
{
EncodeFree(simpleNode);
},
[&](const MemCpyOperation &)
{
EncodeMemcpy(simpleNode);
},
[&](const MemoryStateOperation &)
{
// Nothing needs to be done
},
[&]()
{
// Ensure we took care of all memory state consuming/producing nodes
JLM_ASSERT(!hasMemoryState(simpleNode));
});
}

void
MemoryStateEncoder::EncodeAlloca(const rvsdg::SimpleNode & allocaNode)
{
JLM_ASSERT(is<AllocaOperation>(&allocaNode));
JLM_ASSERT(is<AllocaOperation>(allocaNode.GetOperation()));

auto & stateMap = Context_->GetRegionalizedStateMap();
auto allocaMemoryNodes = stateMap.GetSimpleNodeModRef(allocaNode);
Expand All @@ -770,7 +772,7 @@ MemoryStateEncoder::EncodeAlloca(const rvsdg::SimpleNode & allocaNode)
void
MemoryStateEncoder::EncodeMalloc(const rvsdg::SimpleNode & mallocNode)
{
JLM_ASSERT(is<MallocOperation>(&mallocNode));
JLM_ASSERT(is<MallocOperation>(mallocNode.GetOperation()));
auto & stateMap = Context_->GetRegionalizedStateMap();
auto mallocMemoryNodes = stateMap.GetSimpleNodeModRef(mallocNode);
JLM_ASSERT(mallocMemoryNodes.Size() == 1);
Expand Down Expand Up @@ -798,7 +800,7 @@ MemoryStateEncoder::EncodeMalloc(const rvsdg::SimpleNode & mallocNode)
void
MemoryStateEncoder::EncodeLoad(const rvsdg::SimpleNode & node)
{
JLM_ASSERT(is<LoadOperation>(&node));
JLM_ASSERT(is<LoadOperation>(node.GetOperation()));
auto & stateMap = Context_->GetRegionalizedStateMap();

const auto & memoryNodes = stateMap.GetSimpleNodeModRef(node);
Expand Down Expand Up @@ -839,7 +841,7 @@ MemoryStateEncoder::EncodeStore(const rvsdg::SimpleNode & node)
void
MemoryStateEncoder::EncodeFree(const rvsdg::SimpleNode & freeNode)
{
JLM_ASSERT(is<FreeOperation>(&freeNode));
JLM_ASSERT(is<FreeOperation>(freeNode.GetOperation()));
auto & stateMap = Context_->GetRegionalizedStateMap();

auto address = freeNode.input(0)->origin();
Expand Down Expand Up @@ -892,7 +894,7 @@ MemoryStateEncoder::EncodeCall(const rvsdg::SimpleNode & callNode)
void
MemoryStateEncoder::EncodeMemcpy(const rvsdg::SimpleNode & memcpyNode)
{
JLM_ASSERT(is<MemCpyOperation>(&memcpyNode));
JLM_ASSERT(is<MemCpyOperation>(memcpyNode.GetOperation()));
auto & stateMap = Context_->GetRegionalizedStateMap();

auto memoryNodeStatePairs = stateMap.GetExistingStates(memcpyNode);
Expand Down Expand Up @@ -1116,7 +1118,7 @@ MemoryStateEncoder::ReplaceLoadNode(
const rvsdg::SimpleNode & node,
const std::vector<rvsdg::Output *> & memoryStates)
{
JLM_ASSERT(is<LoadOperation>(&node));
JLM_ASSERT(is<LoadOperation>(node.GetOperation()));

if (const auto loadVolatileOperation =
dynamic_cast<const LoadVolatileOperation *>(&node.GetOperation()))
Expand Down Expand Up @@ -1191,13 +1193,13 @@ MemoryStateEncoder::ReplaceMemcpyNode(
const rvsdg::SimpleNode & memcpyNode,
const std::vector<rvsdg::Output *> & memoryStates)
{
JLM_ASSERT(is<MemCpyOperation>(&memcpyNode));
JLM_ASSERT(is<MemCpyOperation>(memcpyNode.GetOperation()));

auto destination = memcpyNode.input(0)->origin();
auto source = memcpyNode.input(1)->origin();
auto length = memcpyNode.input(2)->origin();

if (is<MemCpyVolatileOperation>(&memcpyNode))
if (is<MemCpyVolatileOperation>(memcpyNode.GetOperation()))
{
auto & ioState = *memcpyNode.input(3)->origin();
auto & newMemcpyNode =
Expand All @@ -1210,7 +1212,7 @@ MemoryStateEncoder::ReplaceMemcpyNode(
// Skip I/O state and only return memory states
return { std::next(results.begin()), results.end() };
}
if (is<MemCpyNonVolatileOperation>(&memcpyNode))
if (is<MemCpyNonVolatileOperation>(memcpyNode.GetOperation()))
{
return MemCpyNonVolatileOperation::create(destination, source, length, memoryStates);
}
Expand Down
Loading
Loading