-
Notifications
You must be signed in to change notification settings - Fork 6k
[Cpp API Compatibility] Align device related APIs #78551
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -70,6 +70,7 @@ third_party/ | |
| *~ | ||
| bazel-* | ||
| .humanize | ||
| .codex | ||
|
|
||
| build_* | ||
| # clion workspace. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,7 +24,12 @@ using gpuStream_t = hipStream_t; | |
| #endif | ||
|
|
||
| #include <c10/core/DeviceType.h> | ||
| #include <c10/util/Exception.h> | ||
|
|
||
| #include <cstddef> | ||
| #include <cstdint> | ||
| #include <functional> | ||
| #include <iosfwd> | ||
| #include <string> | ||
| #include <utility> | ||
|
|
||
|
|
@@ -42,13 +47,19 @@ struct Device final { | |
| index_(place.GetType() == phi::AllocationType::CPU | ||
| ? static_cast<DeviceIndex>(-1) | ||
| : place.GetDeviceId()), | ||
| custom_device_type_(place.GetDeviceType()) {} | ||
| custom_device_type_(place.GetDeviceType()) { | ||
| validate(); | ||
| } | ||
| Device(DeviceType type, DeviceIndex index = -1) | ||
| : type_(type), index_(index) {} // NOLINT | ||
| : type_(type), index_(index) { // NOLINT | ||
| validate(); | ||
| } | ||
| Device(DeviceType type, DeviceIndex index, std::string custom_device_type) | ||
| : type_(type), | ||
| index_(index), | ||
| custom_device_type_(std::move(custom_device_type)) {} // NOLINT | ||
| custom_device_type_(std::move(custom_device_type)) { // NOLINT | ||
| validate(); | ||
| } | ||
|
|
||
| /// Constructs a `Device` from a string description, for convenience. | ||
| /// The string supplied must follow the following schema: | ||
|
|
@@ -63,10 +74,51 @@ struct Device final { | |
|
|
||
| DeviceType type() const noexcept { return type_; } | ||
|
|
||
| bool operator!=(const Device& other) const noexcept { | ||
| return !(*this == other); | ||
| } | ||
|
|
||
| void set_index(DeviceIndex index) { | ||
| index_ = index; | ||
| validate(); | ||
| } | ||
|
|
||
| bool is_cuda() const noexcept { return type_ == DeviceType::CUDA; } | ||
|
|
||
| bool is_privateuseone() const noexcept { | ||
| return type_ == DeviceType::PrivateUse1; | ||
| } | ||
|
|
||
| bool is_mps() const noexcept { return false; } | ||
|
|
||
| bool is_hip() const noexcept { return false; } | ||
|
|
||
| bool is_ve() const noexcept { return false; } | ||
|
|
||
| bool is_xpu() const noexcept { return type_ == DeviceType::XPU; } | ||
|
|
||
| bool is_ipu() const noexcept { return type_ == DeviceType::IPU; } | ||
|
|
||
| bool is_xla() const noexcept { return false; } | ||
|
|
||
| bool is_mtia() const noexcept { return false; } | ||
|
|
||
| bool is_hpu() const noexcept { return false; } | ||
|
|
||
| bool is_lazy() const noexcept { return false; } | ||
|
|
||
| bool is_vulkan() const noexcept { return false; } | ||
|
|
||
| bool is_metal() const noexcept { return false; } | ||
|
|
||
| bool is_maia() const noexcept { return false; } | ||
|
|
||
| bool is_meta() const noexcept { return false; } | ||
|
|
||
| bool is_cpu() const noexcept { return type_ == DeviceType::CPU; } | ||
|
|
||
| bool supports_as_strided() const noexcept { return type_ != DeviceType::IPU; } | ||
|
|
||
| std::string str() const; | ||
|
|
||
| bool operator==(const Device& other) const noexcept { | ||
|
|
@@ -96,12 +148,37 @@ struct Device final { | |
| DeviceType type_{DeviceType::CPU}; | ||
| DeviceIndex index_{-1}; | ||
| std::string custom_device_type_; | ||
|
|
||
| void validate() { | ||
| #ifndef NDEBUG | ||
| TORCH_INTERNAL_ASSERT(index_ >= -1, | ||
| "Device index must be -1 or non-negative, got ", | ||
| static_cast<int>(index_)); | ||
| TORCH_INTERNAL_ASSERT(!is_cpu() || index_ <= 0, | ||
| "CPU device index must be -1 or zero, got ", | ||
| static_cast<int>(index_)); | ||
| #endif | ||
| } | ||
| }; | ||
|
|
||
| std::ostream& operator<<(std::ostream& stream, const Device& device); | ||
|
|
||
| } // namespace c10 | ||
|
|
||
| namespace std { | ||
| template <> | ||
| struct hash<c10::Device> { | ||
| size_t operator()(c10::Device d) const noexcept { | ||
| static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit"); | ||
| static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit"); | ||
| uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type())) | ||
| << 16 | | ||
| static_cast<uint32_t>(static_cast<uint8_t>(d.index())); | ||
| return std::hash<uint32_t>{}(bits); | ||
| } | ||
|
Comment on lines
+168
to
+178
|
||
| }; | ||
| } // namespace std | ||
|
|
||
| namespace at { | ||
| using c10::Device; | ||
| using c10::DeviceIndex; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
parse_type()can reach the end of a non-void function without returning aDeviceType(after theTORCH_CHECK(false, ...)). Even ifTORCH_CHECKthrows at runtime, this is still UB in the language rules and may trigger-Wreturn-type/-Werrorbuild failures. Add an explicitreturn(or an unreachable annotation) after the check to satisfy the compiler.