Skip to content

Commit 6bf5925

Browse files
Kingston MandisodzaGoogle-ML-Automation
authored andcommitted
Add methods for checking coexistence, overlap, and prefix without overlap for AxisRef.
This change introduces: - `AxisRef::CanCoexist`: Determines if two AxisRefs can be part of the same sharding. - `AxisRef::Overlaps`: Checks if two AxisRefs have an overlapping range. - `AxisRef::GetPrefixWithoutOverlap`: Returns the portion of an AxisRef that does not overlap with another. - `SortAndMergeAxes`: Sorts and merges a vector of AxisRefs. - `TruncateAxesByRemovingOverlaps`: Truncates a vector of AxisRefs based on overlaps with another set. PiperOrigin-RevId: 872818416
1 parent 964a0a4 commit 6bf5925

File tree

3 files changed

+298
-0
lines changed

3 files changed

+298
-0
lines changed

xla/hlo/ir/mesh_and_axis.cc

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717

1818
#include <algorithm>
1919
#include <cassert>
20+
#include <cstddef>
2021
#include <cstdint>
2122
#include <iterator>
2223
#include <memory>
@@ -30,6 +31,7 @@ limitations under the License.
3031
#include "absl/algorithm/container.h"
3132
#include "absl/container/flat_hash_set.h"
3233
#include "absl/log/check.h"
34+
#include "absl/log/log.h"
3335
#include "absl/status/status.h"
3436
#include "absl/strings/numbers.h"
3537
#include "absl/strings/str_cat.h"
@@ -273,6 +275,81 @@ AxisRef::AxisRef(int64_t mesh_axis_index, SubAxis sub_axis_info)
273275
CHECK_GT(sub_axis_info_->size, 1) << "sub-axis size must be > 1";
274276
}
275277

278+
namespace {
279+
280+
bool CanSubAxesCoexist(int64_t min_pre_size, int64_t max_pre_size,
281+
int64_t min_next_pre_size, int64_t max_next_pre_size) {
282+
if (min_next_pre_size > max_pre_size) {
283+
// Sub-axes overlap, check if overlapping and non-overlapping parts are
284+
// valid.
285+
return min_next_pre_size % max_pre_size == 0 &&
286+
max_pre_size % min_pre_size == 0 &&
287+
max_next_pre_size % min_next_pre_size == 0;
288+
}
289+
// Sub-axes don't overlap, check if the gap is valid.
290+
return max_pre_size % min_next_pre_size == 0;
291+
}
292+
293+
} // namespace
294+
295+
bool AxisRef::CanCoexist(const AxisRef& other) const {
296+
if (mesh_axis_index() != other.mesh_axis_index()) {
297+
return true;
298+
}
299+
if (!sub_axis_info_.has_value() || !other.sub_axis_info_.has_value()) {
300+
// One of the axes is full
301+
return true;
302+
}
303+
const SubAxis& this_sub = *sub_axis_info_;
304+
const SubAxis& other_sub = *other.sub_axis_info_;
305+
306+
auto [min_pre_size, max_pre_size] =
307+
std::minmax(this_sub.pre_size, other_sub.pre_size);
308+
int64_t min_next_pre_size =
309+
std::min(this_sub.next_pre_size(), other_sub.next_pre_size());
310+
int64_t max_next_pre_size =
311+
std::max(this_sub.next_pre_size(), other_sub.next_pre_size());
312+
313+
return CanSubAxesCoexist(min_pre_size, max_pre_size, min_next_pre_size,
314+
max_next_pre_size);
315+
}
316+
317+
bool AxisRef::Overlaps(const AxisRef& other) const {
318+
if (mesh_axis_index() != other.mesh_axis_index()) {
319+
return false;
320+
}
321+
if (!sub_axis_info_.has_value() || !other.sub_axis_info_.has_value()) {
322+
// One of the axes is full
323+
return true;
324+
}
325+
const SubAxis& this_sub = *sub_axis_info_;
326+
const SubAxis& other_sub = *other.sub_axis_info_;
327+
328+
return this_sub.pre_size < other_sub.next_pre_size() &&
329+
other_sub.pre_size < this_sub.next_pre_size();
330+
}
331+
332+
std::optional<AxisRef> AxisRef::GetPrefixWithoutOverlap(
333+
const AxisRef& other) const {
334+
if (!CanCoexist(other)) {
335+
return std::nullopt;
336+
}
337+
if (!Overlaps(other)) {
338+
return *this;
339+
}
340+
341+
int64_t this_pre_size =
342+
sub_axis_info_.has_value() ? sub_axis_info_->pre_size : 1;
343+
int64_t other_pre_size =
344+
other.sub_axis_info_.has_value() ? other.sub_axis_info_->pre_size : 1;
345+
346+
if (this_pre_size >= other_pre_size) {
347+
return std::nullopt;
348+
}
349+
return AxisRef(mesh_axis_index_,
350+
SubAxis{this_pre_size, other_pre_size / this_pre_size});
351+
}
352+
276353
bool AxisRef::CanCoexistWithoutOverlap(const AxisRef& other) const {
277354
// Check if the axes are on different mesh dimensions. If so, they can always
278355
// coexist and never overlap.
@@ -414,4 +491,51 @@ absl::Status ValidateSpanOfAxes(absl::Span<const AxisRef> axes,
414491
return absl::OkStatus();
415492
}
416493

494+
void SortAndMergeAxes(std::vector<AxisRef>& axes, const Mesh& mesh) {
495+
if (axes.empty()) {
496+
return;
497+
}
498+
499+
absl::c_sort(axes);
500+
501+
auto current = axes.begin();
502+
for (auto next = current + 1; next != axes.end(); ++next) {
503+
if (current->Overlaps(*next)) {
504+
LOG(FATAL) << "Axes should not overlap: " << current->ToString(&mesh)
505+
<< " and " << next->ToString(&mesh);
506+
}
507+
if (current->CanMerge(*next)) {
508+
CHECK(current->Merge(*next, mesh));
509+
} else {
510+
current++;
511+
*current = *next;
512+
}
513+
}
514+
axes.erase(current + 1, axes.end());
515+
}
516+
517+
bool TruncateAxesByRemovingOverlaps(std::vector<AxisRef>& axes,
518+
absl::Span<const AxisRef> other_axis_refs) {
519+
for (size_t i = 0; i < axes.size(); ++i) {
520+
std::optional<AxisRef> prefix = axes[i];
521+
for (const AxisRef& other : other_axis_refs) {
522+
prefix = prefix->GetPrefixWithoutOverlap(other);
523+
if (!prefix) {
524+
break;
525+
}
526+
}
527+
528+
if (!prefix) {
529+
axes.erase(axes.begin() + i, axes.end());
530+
return true;
531+
}
532+
if (axes[i] != *prefix) {
533+
axes[i] = *prefix;
534+
axes.erase(axes.begin() + i + 1, axes.end());
535+
return true;
536+
}
537+
}
538+
return false;
539+
}
540+
417541
} // namespace xla

xla/hlo/ir/mesh_and_axis.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,22 @@ class AxisRef {
166166

167167
bool operator!=(const xla::AxisRef& other) const { return !(*this == other); }
168168

169+
bool operator<(const AxisRef& other) const {
170+
if (mesh_axis_index_ != other.mesh_axis_index_) {
171+
return mesh_axis_index_ < other.mesh_axis_index_;
172+
}
173+
if (!sub_axis_info_.has_value()) {
174+
return true;
175+
}
176+
if (!other.sub_axis_info_.has_value()) {
177+
return false;
178+
}
179+
if (sub_axis_info_->pre_size != other.sub_axis_info_->pre_size) {
180+
return sub_axis_info_->pre_size < other.sub_axis_info_->pre_size;
181+
}
182+
return sub_axis_info_->size < other.sub_axis_info_->size;
183+
}
184+
169185
template <typename H>
170186
friend H AbslHashValue(H h, const AxisRef& a) {
171187
return H::combine(std::move(h), a.mesh_axis_index_, a.sub_axis_info_);
@@ -179,6 +195,29 @@ class AxisRef {
179195

180196
bool CanCoexistWithoutOverlap(const AxisRef& other) const;
181197

198+
// Returns whether this axis and `other` can coexist in the same mesh:
199+
// * If they overlap, then both overlapping and non-overlapping parts must
200+
// be valid axes or sub-axes.
201+
// * Otherwise, both axes can be used to shard the same tensor.
202+
//
203+
// For example:
204+
// "a", "b" -> true
205+
// "a", "b":(2)2 -> true
206+
// "a", "a" -> true
207+
// "a", "a":(2)2 -> true
208+
// "a":(1)2, "a":(4)2 -> true
209+
// "a":(1)4, "a":(2)4 -> true
210+
// "a":(1)2, "a":(1)3 -> false
211+
// "a":(1)2, "a":(3)2 -> false
212+
// "a":(1)3, "a":(2)3 -> false
213+
bool CanCoexist(const AxisRef& other) const;
214+
215+
// Returns true if this axis overlaps with `other`.
216+
bool Overlaps(const AxisRef& other) const;
217+
218+
// Returns the largest prefix of this axis that does not overlap with `other`.
219+
std::optional<AxisRef> GetPrefixWithoutOverlap(const AxisRef& other) const;
220+
182221
// Returns true if this AxisRef can be merged with the `other`, i.e., they are
183222
// consecutive sub-axes of same full axis and this sub-axis is major to other.
184223
bool CanMerge(const AxisRef& other) const;
@@ -209,6 +248,19 @@ absl::Status ValidateSpanOfAxes(absl::Span<const AxisRef> axes,
209248
const Mesh& mesh,
210249
bool allow_mergeable_neighbors = false);
211250

251+
// Sorts and merges the axes in `axes`.
252+
//
253+
// The axes are sorted by `operator<` (mesh axis index, then pre-size).
254+
// Adjacent axes that overlap will cause a fatal error.
255+
// Adjacent axes that can be merged are merged.
256+
void SortAndMergeAxes(std::vector<AxisRef>& axes, const Mesh& mesh);
257+
258+
// Removes parts of `axes` that overlap with any axis in `other_axis_refs`.
259+
//
260+
// Returns true if `axes` is modified.
261+
bool TruncateAxesByRemovingOverlaps(std::vector<AxisRef>& axes,
262+
absl::Span<const AxisRef> other_axis_refs);
263+
212264
} // namespace xla
213265

214266
#endif // XLA_HLO_IR_MESH_AND_AXIS_H_

xla/hlo/ir/mesh_and_axis_test.cc

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "xla/hlo/ir/mesh_and_axis.h"
1717

1818
#include <cstdint>
19+
#include <optional>
1920
#include <vector>
2021

2122
#include <gmock/gmock.h>
@@ -432,4 +433,125 @@ TEST(MeshAndAxisTest, AxisRefMerge) {
432433
EXPECT_EQ(axis_ref4, AxisRef(0, {2, 4}));
433434
}
434435

436+
TEST(MeshAndAxisTest, CanCoexist_DifferentAxes) {
437+
EXPECT_TRUE(AxisRef(0).CanCoexist(AxisRef(1)));
438+
EXPECT_TRUE(AxisRef(0, {1, 2}).CanCoexist(AxisRef(1, {1, 2})));
439+
}
440+
441+
TEST(MeshAndAxisTest, CanCoexist_CompatibleSubAxes) {
442+
EXPECT_TRUE(AxisRef(0, {1, 2}).CanCoexist(AxisRef(0, {2, 2})));
443+
EXPECT_TRUE(AxisRef(0, {1, 2}).CanCoexist(AxisRef(0, {4, 2})));
444+
}
445+
446+
TEST(MeshAndAxisTest, CanCoexist_CompatibleOverlappingSubAxes) {
447+
EXPECT_TRUE(AxisRef(0, {1, 2}).CanCoexist(AxisRef(0, {1, 2})));
448+
EXPECT_TRUE(AxisRef(0, {1, 2}).CanCoexist(AxisRef(0, {1, 4})));
449+
}
450+
451+
TEST(MeshAndAxisTest, CanCoexist_IncompatibleSubAxes) {
452+
EXPECT_FALSE(AxisRef(0, {1, 2}).CanCoexist(AxisRef(0, {1, 3})));
453+
EXPECT_FALSE(AxisRef(0, {1, 2}).CanCoexist(AxisRef(0, {3, 2})));
454+
}
455+
456+
TEST(MeshAndAxisTest, Overlaps_DifferentAxes) {
457+
EXPECT_FALSE(AxisRef(0).Overlaps(AxisRef(1)));
458+
}
459+
460+
TEST(MeshAndAxisTest, Overlaps_FullAxes) {
461+
EXPECT_TRUE(AxisRef(0).Overlaps(AxisRef(0)));
462+
EXPECT_TRUE(AxisRef(0).Overlaps(AxisRef(0, {1, 2})));
463+
}
464+
465+
TEST(MeshAndAxisTest, Overlaps_SubAxes_NoOverlap) {
466+
EXPECT_FALSE(AxisRef(0, {1, 2}).Overlaps(AxisRef(0, {2, 2})));
467+
}
468+
469+
TEST(MeshAndAxisTest, Overlaps_SubAxes_Overlap) {
470+
EXPECT_TRUE(AxisRef(0, {1, 2}).Overlaps(AxisRef(0, {1, 2})));
471+
EXPECT_TRUE(AxisRef(0, {1, 2}).Overlaps(AxisRef(0, {1, 4})));
472+
EXPECT_TRUE(AxisRef(0, {1, 4}).Overlaps(AxisRef(0, {2, 2})));
473+
}
474+
475+
TEST(MeshAndAxisTest, GetPrefixWithoutOverlap_NoOverlap) {
476+
EXPECT_EQ(AxisRef(0).GetPrefixWithoutOverlap(AxisRef(1)), AxisRef(0));
477+
}
478+
479+
TEST(MeshAndAxisTest, GetPrefixWithoutOverlap_FullOverlap) {
480+
EXPECT_EQ(AxisRef(0, {1, 4}).GetPrefixWithoutOverlap(AxisRef(0, {1, 2})),
481+
std::nullopt);
482+
}
483+
484+
TEST(MeshAndAxisTest, GetPrefixWithoutOverlap_PartialOverlap) {
485+
EXPECT_EQ(AxisRef(0, {1, 4}).GetPrefixWithoutOverlap(AxisRef(0, {2, 2})),
486+
AxisRef(0, {1, 2}));
487+
}
488+
489+
TEST(MeshAndAxisTest, GetPrefixWithoutOverlap_PartialOverlap_NoPrefix) {
490+
EXPECT_EQ(AxisRef(0, {1, 4}).GetPrefixWithoutOverlap(AxisRef(0, {1, 2})),
491+
std::nullopt);
492+
}
493+
494+
TEST(MeshAndAxisTest, SortAndMergeAxes) {
495+
Mesh mesh({16, 16}, {"x", "y"});
496+
std::vector<AxisRef> axes = {AxisRef(0, {2, 2}), AxisRef(0, {4, 2}),
497+
AxisRef(0, {1, 2}), AxisRef(1, {1, 2}),
498+
AxisRef(1, {4, 2}), AxisRef(1, {2, 2})};
499+
SortAndMergeAxes(axes, mesh);
500+
501+
EXPECT_THAT(axes,
502+
testing::ElementsAre(AxisRef(0, {1, 8}), AxisRef(1, {1, 8})));
503+
}
504+
505+
TEST(MeshAndAxisTest, SortAndMergeAxesFull) {
506+
Mesh mesh({4}, {"x"});
507+
std::vector<AxisRef> axes = {AxisRef(0, {1, 2}), AxisRef(0, {2, 2})};
508+
SortAndMergeAxes(axes, mesh);
509+
510+
EXPECT_THAT(axes, testing::ElementsAre(AxisRef(0)));
511+
}
512+
513+
TEST(MeshAndAxisTest, TruncateAxesByRemovingOverlaps_PartialOverlap) {
514+
std::vector<AxisRef> axes = {AxisRef(0, {1, 4})};
515+
std::vector<AxisRef> other = {AxisRef(0, {2, 2})};
516+
517+
EXPECT_TRUE(TruncateAxesByRemovingOverlaps(axes, other));
518+
EXPECT_THAT(axes, testing::ElementsAre(AxisRef(0, {1, 2})));
519+
}
520+
521+
TEST(MeshAndAxisTest, TruncateAxesByRemovingOverlaps_FullOverlap) {
522+
std::vector<AxisRef> axes = {AxisRef(0, {1, 2})};
523+
std::vector<AxisRef> other = {AxisRef(0, {1, 4})};
524+
525+
EXPECT_TRUE(TruncateAxesByRemovingOverlaps(axes, other));
526+
EXPECT_THAT(axes, testing::IsEmpty());
527+
}
528+
529+
TEST(MeshAndAxisTest, TruncateAxesByRemovingOverlaps_MultipleAxes) {
530+
std::vector<AxisRef> axes = {AxisRef(0, {1, 4}), AxisRef(1)};
531+
std::vector<AxisRef> other = {AxisRef(0, {2, 2})};
532+
533+
EXPECT_TRUE(TruncateAxesByRemovingOverlaps(axes, other));
534+
EXPECT_THAT(axes, testing::ElementsAre(AxisRef(0, {1, 2})));
535+
}
536+
537+
TEST(MeshAndAxisTest, TruncateAxesByRemovingOverlaps_NoOverlap) {
538+
std::vector<AxisRef> axes = {AxisRef(0, {1, 2}), AxisRef(1)};
539+
std::vector<AxisRef> other = {AxisRef(2)};
540+
541+
EXPECT_FALSE(TruncateAxesByRemovingOverlaps(axes, other));
542+
EXPECT_THAT(axes, testing::ElementsAre(AxisRef(0, {1, 2}), AxisRef(1)));
543+
}
544+
545+
TEST(MeshAndAxisTest, OperatorLess) {
546+
// "x" < "x":(1)2
547+
EXPECT_TRUE(AxisRef(0) < AxisRef(0, {1, 2}));
548+
EXPECT_FALSE(AxisRef(0, {1, 2}) < AxisRef(0));
549+
550+
// "x":(1)2 < "x":(2)2
551+
EXPECT_TRUE(AxisRef(0, {1, 2}) < AxisRef(0, {2, 2}));
552+
553+
// Different axes
554+
EXPECT_TRUE(AxisRef(0) < AxisRef(1));
555+
}
556+
435557
} // namespace xla

0 commit comments

Comments
 (0)