Skip to content

Commit 2b80205

Browse files
mehrdadkhaniGoogle-ML-Automation
authored andcommitted
Make window prefetching in memory space assignment more deterministic.
PiperOrigin-RevId: 872073012
1 parent b27c8af commit 2b80205

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

xla/service/memory_space_assignment/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,8 @@ cc_library(
618618
"@com_google_absl//absl/container:flat_hash_map",
619619
"@com_google_absl//absl/container:flat_hash_set",
620620
"@com_google_absl//absl/container:inlined_vector",
621+
"@com_google_absl//absl/container:linked_hash_map",
622+
"@com_google_absl//absl/container:linked_hash_set",
621623
"@com_google_absl//absl/functional:any_invocable",
622624
"@com_google_absl//absl/hash",
623625
"@com_google_absl//absl/log",

xla/service/memory_space_assignment/algorithm.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ limitations under the License.
3939
#include "absl/container/flat_hash_map.h"
4040
#include "absl/container/flat_hash_set.h"
4141
#include "absl/container/inlined_vector.h"
42+
#include "absl/container/linked_hash_map.h"
4243
#include "absl/functional/any_invocable.h"
4344
#include "absl/hash/hash.h"
4445
#include "absl/log/check.h"
@@ -7203,8 +7204,10 @@ absl::Status MsaAlgorithm::WindowPrefetch() {
72037204
// cloned computation and use the cloned computation to determine the operand
72047205
// span size.
72057206

7206-
// Map of the original instruction to a clone of the instruction.
7207-
absl::flat_hash_map<HloInstruction*, HloInstruction*> cloned_insts;
7207+
// Map of the original instruction to a clone of the instruction. Use a
7208+
// linked_hash_map to ensure deterministic traversal for memory space
7209+
// propagation and cleanup.
7210+
absl::linked_hash_map<HloInstruction*, HloInstruction*> cloned_insts;
72087211
const std::vector<HloInstruction*>& instruction_sequence =
72097212
hlo_live_range_.flattened_instruction_sequence().instructions();
72107213
for (HloInstruction* instruction : instruction_sequence) {

xla/service/memory_space_assignment/algorithm.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ limitations under the License.
2929
#include <utility>
3030
#include <vector>
3131

32+
#include "absl/container/linked_hash_set.h"
33+
3234
// TODO(b/210891274): Use btree_map after build issue in Windows is resolved.
3335
#if defined(__GNUC__) || defined(__clang__)
3436
#include "absl/container/btree_map.h"
@@ -1362,8 +1364,9 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap<HloValue> {
13621364
absl::flat_hash_map<HloPosition, std::vector<int64_t>>
13631365
default_memory_coloring_requirements_;
13641366

1365-
// Set of HloUses that are in the default memory.
1366-
absl::flat_hash_set<HloUse> uses_in_default_memory_;
1367+
// Set of HloUses that are in the default memory. Using linked_hash_set for
1368+
// deterministic window prefetching results.
1369+
absl::linked_hash_set<HloUse> uses_in_default_memory_;
13671370
};
13681371

13691372
} // namespace memory_space_assignment

0 commit comments

Comments
 (0)