Skip to content

Commit bdd6a01

Browse files
EusebioDMGoogle-ML-Automation
authored andcommitted
Introduce CubScratchSizeDevicelessLookup class
PiperOrigin-RevId: 901170433
1 parent bda66ea commit bdd6a01

8 files changed

Lines changed: 577 additions & 8 deletions

File tree

xla/backends/gpu/libraries/cub/BUILD

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16+
load("//xla:xla.default.bzl", "xla_cc_test")
1617
load("//xla/tsl/platform:build_config.bzl", "tf_proto_library")
18+
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
1719

1820
package(
1921
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -24,3 +26,47 @@ tf_proto_library(
2426
name = "scratch_space_lookup_table_proto",
2527
srcs = ["scratch_space_lookup_table.proto"],
2628
)
29+
30+
# Library built from cub_scratch_size_deviceless_lookup.{h,cc}. It depends on
# the C++ target of the scratch-space lookup-table proto and on :cub_sort_utils.
cc_library(
31+
name = "cub_scratch_size_deviceless_lookup",
32+
srcs = ["cub_scratch_size_deviceless_lookup.cc"],
33+
hdrs = ["cub_scratch_size_deviceless_lookup.h"],
34+
deps = [
35+
":cub_sort_utils",
36+
":scratch_space_lookup_table_proto_cc",
37+
"//xla/stream_executor:semantic_version",
38+
"@com_google_absl//absl/status",
39+
"@com_google_absl//absl/status:statusor",
40+
"@com_google_absl//absl/strings",
41+
],
42+
)
43+
44+
# Header-only library (no srcs) exposing cub_sort_utils.h. It has no
# dependencies and is visible to //xla:internal.
cc_library(
45+
name = "cub_sort_utils",
46+
hdrs = ["cub_sort_utils.h"],
47+
visibility = ["//xla:internal"],
48+
deps = [],
49+
)
50+
51+
# Test target for :cub_scratch_size_deviceless_lookup, built from
# cub_scratch_size_deviceless_lookup_test.cc and linked against gtest_main.
xla_cc_test(
52+
name = "cub_scratch_size_deviceless_lookup_test",
53+
srcs = ["cub_scratch_size_deviceless_lookup_test.cc"],
54+
deps = [
55+
":cub_scratch_size_deviceless_lookup",
56+
":scratch_space_lookup_table_proto_cc",
57+
"//xla/stream_executor:semantic_version",
58+
"//xla/tsl/util/proto:parse_text_proto",
59+
"@com_google_absl//absl/status",
60+
"@com_google_absl//absl/status:status_matchers",
61+
"@com_google_googletest//:gtest_main",
62+
],
63+
)
64+
65+
# Test target for :cub_sort_utils, built from cub_sort_utils_test.cc and linked
# against gtest_main.
xla_cc_test(
66+
name = "cub_sort_utils_test",
67+
srcs = ["cub_sort_utils_test.cc"],
68+
deps = [
69+
":cub_sort_utils",
70+
"@com_google_googletest//:gtest_main",
71+
],
72+
)
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/* Copyright 2026 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/backends/gpu/libraries/cub/cub_scratch_size_deviceless_lookup.h"
17+
18+
#include <algorithm>
19+
#include <cstdint>
20+
#include <optional>
21+
#include <utility>
22+
23+
#include "absl/status/status.h"
24+
#include "absl/status/statusor.h"
25+
#include "absl/strings/string_view.h"
26+
#include "xla/backends/gpu/libraries/cub/cub_sort_utils.h"
27+
#include "xla/stream_executor/semantic_version.h"
28+
29+
namespace xla::gpu {
30+
31+
// Builds a lookup object from `proto` after validating it.
//
// Within every entry, `scratch_size_recordings` must be in strictly increasing
// `num_items` order (Lookup() binary-searches them). If any adjacent pair
// violates that, an InvalidArgument error is returned instead.
absl::StatusOr<CubScratchSizeDevicelessLookup>
CubScratchSizeDevicelessLookup::Create(CubScratchSizeLookupTable proto) {
  // True when `next` does not hold strictly more items than `prev`.
  const auto out_of_order = [](const auto& prev, const auto& next) {
    return next.num_items() <= prev.num_items();
  };

  for (const auto& entry : proto.entries()) {
    const auto& recordings = entry.scratch_size_recordings();
    const bool has_violation =
        std::adjacent_find(recordings.begin(), recordings.end(),
                           out_of_order) != recordings.end();
    if (has_violation) {
      return absl::InvalidArgumentError(
          "scratch_size_recordings must be sorted by num_items");
    }
  }
  return CubScratchSizeDevicelessLookup(std::move(proto));
}
44+
45+
CubScratchSizeDevicelessLookup::CubScratchSizeDevicelessLookup(
46+
CubScratchSizeLookupTable proto)
47+
: proto_(std::move(proto)) {}
48+
49+
// Returns the first table entry that matches every given parameter, or nullptr
// when there is none. The returned pointer refers into `proto_`.
//
// An entry matches when:
//   * one of its `cub_version` strings equals `cub_version.ToString()`,
//   * its `device_name`, `key_type_size` and `is_segmented` equal the
//     arguments, and
//   * its `value_type_size` equals `value_type_size`, where std::nullopt is
//     treated as 0.
const CubScratchSizeEntry* CubScratchSizeDevicelessLookup::FindEntry(
    stream_executor::SemanticVersion cub_version, absl::string_view device_name,
    int32_t key_type_size, std::optional<int32_t> value_type_size,
    bool is_segmented) const {
  // Both values are loop-invariant; compute them once instead of building a
  // fresh version string for every entry in the table.
  const auto wanted_version = cub_version.ToString();
  const int32_t wanted_value_type_size = value_type_size.value_or(0);

  for (const CubScratchSizeEntry& entry : proto_.entries()) {
    // Check the cheap field comparisons first so the linear scan over the
    // entry's version list only runs for otherwise-matching entries.
    const bool fields_match =
        entry.device_name() == device_name &&
        entry.key_type_size() == key_type_size &&
        entry.value_type_size() == wanted_value_type_size &&
        entry.is_segmented() == is_segmented;
    if (!fields_match) {
      continue;
    }

    const auto& versions = entry.cub_version();
    if (std::find(versions.begin(), versions.end(), wanted_version) !=
        versions.end()) {
      return &entry;
    }
  }
  return nullptr;
}
67+
68+
// Returns a scratch-size estimate based on the first recording whose
// `num_items` is >= the requested `num_items`. Returns std::nullopt when no
// entry matches the parameters or when the request exceeds every recording.
//
// Entries are matched as segmented when `batch_size > 1`. The recorded byte
// count is passed through AddSegmentedSortOffsetsToScratchSize together with
// `batch_size` before being returned.
std::optional<int64_t> CubScratchSizeDevicelessLookup::Lookup(
    stream_executor::SemanticVersion cub_version, absl::string_view device_name,
    int32_t key_type_size, std::optional<int32_t> value_type_size,
    int64_t num_items, int64_t batch_size) const {
  const bool is_segmented = batch_size > 1;
  const CubScratchSizeEntry* const entry = FindEntry(
      cub_version, device_name, key_type_size, value_type_size, is_segmented);
  if (entry == nullptr) {
    return std::nullopt;
  }

  // Recordings are strictly increasing in num_items (checked by Create()), so
  // a binary search finds the first one large enough for the request.
  const auto& recordings = entry->scratch_size_recordings();
  const auto first_large_enough = std::partition_point(
      recordings.begin(), recordings.end(),
      [num_items](const CubScratchSizeEntry::ScratchSizeRecord& record) {
        return record.num_items() < num_items;
      });
  if (first_large_enough == recordings.end()) {
    return std::nullopt;
  }

  return AddSegmentedSortOffsetsToScratchSize(
      first_large_enough->scratch_space_bytes(), batch_size);
}
92+
93+
// Reports whether the table holds data covering the request: a matching entry
// must exist, have at least one recording, and its last (largest) recorded
// `num_items` must be >= the requested `num_items`.
//
// Entries are matched as segmented when `batch_size > 1`.
bool CubScratchSizeDevicelessLookup::CanLookup(
    stream_executor::SemanticVersion cub_version, absl::string_view device_name,
    int32_t key_type_size, std::optional<int32_t> value_type_size,
    int64_t num_items, int64_t batch_size) const {
  const CubScratchSizeEntry* const entry =
      FindEntry(cub_version, device_name, key_type_size, value_type_size,
                /*is_segmented=*/batch_size > 1);
  if (entry == nullptr) {
    return false;
  }

  const auto& recordings = entry->scratch_size_recordings();
  if (recordings.empty()) {
    return false;
  }

  // Recordings are strictly increasing in num_items (checked by Create()), so
  // the last one carries the largest recorded size.
  const auto& largest = recordings.Get(recordings.size() - 1);
  return largest.num_items() >= num_items;
}
107+
108+
} // namespace xla::gpu
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/* Copyright 2026 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_BACKENDS_GPU_LIBRARIES_CUB_CUB_SCRATCH_SIZE_DEVICELESS_LOOKUP_H_
17+
#define XLA_BACKENDS_GPU_LIBRARIES_CUB_CUB_SCRATCH_SIZE_DEVICELESS_LOOKUP_H_
18+
19+
#include <cstdint>
20+
#include <optional>
21+
22+
#include "absl/status/statusor.h"
23+
#include "absl/strings/string_view.h"
24+
#include "xla/backends/gpu/libraries/cub/scratch_space_lookup_table.pb.h"
25+
#include "xla/stream_executor/semantic_version.h"
26+
27+
namespace xla::gpu {
28+
29+
// A lookup table that returns an estimate of the scratch space required by
30+
// CUB sorts without requiring a GPU.
31+
class CubScratchSizeDevicelessLookup {
32+
public:
33+
// Creates a lookup backed by `proto`. Returns an InvalidArgument error unless
// every entry's scratch_size_recordings are in strictly increasing num_items
// order.
static absl::StatusOr<CubScratchSizeDevicelessLookup> Create(
34+
CubScratchSizeLookupTable proto);
35+
36+
// Looks up the estimated scratch space CUB will need for the given
37+
// parameters. The estimated space will be >= the actual space CUB will
38+
// need.
39+
//
40+
// Will return std::nullopt if we can't estimate it, i.e., if we have no
41+
// entries for the given parameters, or if the requested num_items is greater
42+
// than any recorded num_items for the given parameters.
//
// A value_type_size of std::nullopt matches entries whose value_type_size is
// 0. A batch_size greater than 1 selects entries recorded as segmented; any
// other batch_size selects non-segmented entries.
43+
std::optional<int64_t> Lookup(stream_executor::SemanticVersion cub_version,
44+
absl::string_view device_name,
45+
int32_t key_type_size,
46+
std::optional<int32_t> value_type_size,
47+
int64_t num_items,
48+
int64_t batch_size = 1) const;
49+
50+
// Cheaper check to see if we have data for the given parameters.
51+
//
52+
// Returns true if a matching entry exists and the requested num_items does
53+
// not exceed the largest recorded num_items in that entry.
// Entries are matched with the same rules as in Lookup().
54+
bool CanLookup(stream_executor::SemanticVersion cub_version,
55+
absl::string_view device_name, int32_t key_type_size,
56+
std::optional<int32_t> value_type_size, int64_t num_items,
57+
int64_t batch_size = 1) const;
58+
59+
private:
60+
// Private: instances are obtained through Create(), which validates `proto`
// before constructing.
explicit CubScratchSizeDevicelessLookup(CubScratchSizeLookupTable proto);
61+
62+
// Returns the first entry in proto_ that matches all of the given parameters,
// or nullptr if there is none. The returned pointer refers into proto_.
const CubScratchSizeEntry* FindEntry(
63+
stream_executor::SemanticVersion cub_version,
64+
absl::string_view device_name, int32_t key_type_size,
65+
std::optional<int32_t> value_type_size, bool is_segmented) const;
66+
67+
// The table that all queries are answered from.
CubScratchSizeLookupTable proto_;
68+
};
69+
70+
} // namespace xla::gpu
71+
72+
#endif // XLA_BACKENDS_GPU_LIBRARIES_CUB_CUB_SCRATCH_SIZE_DEVICELESS_LOOKUP_H_

0 commit comments

Comments
 (0)