Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions xla/hlo/analysis/hlo_reachability.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,22 @@ std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
return result;
}

void HloReachabilityMap::UpdateReachabilityThroughInstruction(
const HloInstruction* instruction) {
void HloReachabilityMap::UpdateReachabilityThroughInstructions(
absl::Span<const HloInstruction* const> instructions) {
std::queue<const HloInstruction*> worklist;
worklist.push(instruction);

std::vector<HloInstruction*> inputs;

// Keep track of the number of times an instruction is in the worklist and
// only process it only if it is the last occurrence. Note that this might
// still mean that an instruction is processed multiple times.
absl::flat_hash_map<const HloInstruction*, int64_t> in_worklist;

for (const HloInstruction* instruction : instructions) {
worklist.push(instruction);
++in_worklist[instruction];
}

std::vector<HloInstruction*> inputs;

while (!worklist.empty()) {
const HloInstruction* item = worklist.front();
worklist.pop();
Expand Down
9 changes: 8 additions & 1 deletion xla/hlo/analysis/hlo_reachability.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,16 @@ class HloReachabilityMap {
}
void SetReachable(Index a, Index b) { BitSetFromIndex(b).Set(a); }

// Updates the given reachability map after the immediate predecessor set
// (operands and control predecessors) of a set of instructions has changed.
void UpdateReachabilityThroughInstructions(
absl::Span<const HloInstruction* const> instructions);

// Updates the given reachability map after the immediate predecessor set
// (operands and control predecessors) of 'instruction' has changed.
void UpdateReachabilityThroughInstruction(const HloInstruction* instruction);
void UpdateReachabilityThroughInstruction(const HloInstruction* instruction) {
UpdateReachabilityThroughInstructions({instruction});
}

// Returns true if "b" is reachable from "a"
//
Expand Down
7 changes: 6 additions & 1 deletion xla/service/copy_removal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ bool ComputeRelativeLocation::AddControlDependenceForUnorderedOps() {
for (const auto& comp_it : ctrl_deps_) {
HloComputation* parent = comp_it.first;
HloReachabilityMap& reachability_map = ordering->reachability_map(parent);
std::vector<HloInstruction*> entries_to_update;
for (const auto& instr_it : comp_it.second) {
HloInstruction* entry1 = instr_it.first;
for (HloInstruction* entry2 : instr_it.second) {
Expand All @@ -321,7 +322,11 @@ bool ComputeRelativeLocation::AddControlDependenceForUnorderedOps() {
VLOG(3) << " successor: " << entry1->name();
CHECK_OK(entry2->AddControlDependencyTo(entry1));
}
reachability_map.UpdateReachabilityThroughInstruction(entry1);
entries_to_update.push_back(entry1);
}
reachability_map.UpdateReachabilityThroughInstructions(entries_to_update);
for (const auto& instr_it : comp_it.second) {
HloInstruction* entry1 = instr_it.first;
for (HloInstruction* entry2 : instr_it.second) {
DCHECK(ordering_->GetExecutionConstraint(entry1, entry2) ==
HloOrdering::ExecutionConstraint::kRunAfter);
Expand Down
Loading