Skip to content

Commit 3308e99

Browse files
youge325YuhanXu
authored andcommitted
[Cpp API Compatibility] Align device related APIs (PaddlePaddle#78551)
1 parent d0ffd56 commit 3308e99

File tree

6 files changed

+239
-38
lines changed

6 files changed

+239
-38
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ third_party/
7070
*~
7171
bazel-*
7272
.humanize
73+
.codex
7374

7475
build_*
7576
# clion workspace.

paddle/phi/api/include/compat/c10/core/Device.cpp

Lines changed: 85 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@
1818

1919
#include <c10/core/Device.h>
2020
#include <c10/util/Exception.h>
21+
22+
#include <algorithm>
2123
#include <array>
24+
#include <cctype>
25+
#include <exception>
26+
#include <string>
27+
2228
#include "paddle/common/enforce.h"
2329

2430
namespace c10 {
@@ -45,44 +51,98 @@ const char* DeviceTypeToString(DeviceType type) {
4551

4652
DeviceType parse_type(const std::string& device_string) {
4753
static const std::array<std::pair<const char*, DeviceType>,
48-
static_cast<size_t>(4)>
54+
static_cast<size_t>(5)>
4955
types = {{
5056
{"cpu", DeviceType::CPU},
5157
{"cuda", DeviceType::CUDA},
5258
{"ipu", DeviceType::IPU},
5359
{"xpu", DeviceType::XPU},
60+
{"privateuseone", DeviceType::PrivateUse1},
5461
}};
55-
for (const auto& type_pair : types) {
56-
if (device_string == type_pair.first) {
57-
return type_pair.second;
58-
}
62+
auto device = std::find_if(
63+
types.begin(),
64+
types.end(),
65+
[&device_string](const std::pair<const char*, DeviceType>& p) {
66+
return p.first && p.first == device_string;
67+
});
68+
if (device != types.end()) {
69+
return device->second;
5970
}
60-
PADDLE_THROW(::common::errors::InvalidArgument(
61-
"Unknown device type: '%s'. Supported device types are ",
62-
"'cpu', 'cuda', 'ipu', and 'xpu'.",
63-
device_string));
71+
TORCH_CHECK(false,
72+
"Expected one of cpu, cuda, ipu, xpu, privateuseone device type "
73+
"at start of device string: ",
74+
device_string);
6475
}
6576

77+
enum DeviceStringParsingState { kStart, kIndexStart, kIndexRest, kError };
78+
6679
Device::Device(const std::string& device_string) : Device(Type::CPU) {
6780
TORCH_CHECK(!device_string.empty(), "Device string must not be empty");
68-
auto colon_pos = device_string.find(':');
69-
std::string type_str = colon_pos == std::string::npos
70-
? device_string
71-
: device_string.substr(0, colon_pos);
72-
type_ = parse_type(type_str);
73-
index_ = -1;
74-
if (colon_pos != std::string::npos) {
75-
std::string index_str = device_string.substr(colon_pos + 1);
76-
try {
77-
index_ = static_cast<DeviceIndex>(std::stoi(index_str));
78-
} catch (const std::invalid_argument&) {
79-
PADDLE_THROW(::common::errors::InvalidArgument(
80-
"Invalid device index: '%s' is not a number.", index_str));
81-
} catch (const std::out_of_range&) {
82-
PADDLE_THROW(::common::errors::InvalidArgument(
83-
"Invalid device index: '%s' is out of range.", index_str));
81+
82+
std::string device_name, device_index_str;
83+
DeviceStringParsingState pstate = DeviceStringParsingState::kStart;
84+
85+
for (size_t i = 0;
86+
pstate != DeviceStringParsingState::kError && i < device_string.size();
87+
++i) {
88+
const char ch = device_string.at(i);
89+
const unsigned char uch = static_cast<unsigned char>(ch);
90+
switch (pstate) {
91+
case DeviceStringParsingState::kStart:
92+
if (ch != ':') {
93+
if (std::isalpha(uch) || ch == '_') {
94+
device_name.push_back(ch);
95+
} else {
96+
pstate = DeviceStringParsingState::kError;
97+
}
98+
} else {
99+
pstate = DeviceStringParsingState::kIndexStart;
100+
}
101+
break;
102+
case DeviceStringParsingState::kIndexStart:
103+
if (std::isdigit(uch)) {
104+
device_index_str.push_back(ch);
105+
pstate = DeviceStringParsingState::kIndexRest;
106+
} else {
107+
pstate = DeviceStringParsingState::kError;
108+
}
109+
break;
110+
case DeviceStringParsingState::kIndexRest:
111+
if (device_index_str.at(0) == '0') {
112+
pstate = DeviceStringParsingState::kError;
113+
break;
114+
}
115+
if (std::isdigit(uch)) {
116+
device_index_str.push_back(ch);
117+
} else {
118+
pstate = DeviceStringParsingState::kError;
119+
}
120+
break;
121+
case DeviceStringParsingState::kError:
122+
break;
123+
}
124+
}
125+
126+
const bool has_error = device_name.empty() ||
127+
pstate == DeviceStringParsingState::kError ||
128+
(pstate == DeviceStringParsingState::kIndexStart &&
129+
device_index_str.empty());
130+
TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'");
131+
132+
try {
133+
if (!device_index_str.empty()) {
134+
index_ = static_cast<DeviceIndex>(std::stoi(device_index_str));
84135
}
136+
} catch (const std::exception&) {
137+
TORCH_CHECK(false,
138+
"Could not parse device index '",
139+
device_index_str,
140+
"' in device string '",
141+
device_string,
142+
"'");
85143
}
144+
type_ = parse_type(device_name);
145+
validate();
86146
}
87147

88148
std::string Device::str() const {

paddle/phi/api/include/compat/c10/core/Device.h

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@ using gpuStream_t = hipStream_t;
2424
#endif
2525

2626
#include <c10/core/DeviceType.h>
27+
#include <c10/util/Exception.h>
2728

29+
#include <cstddef>
30+
#include <cstdint>
31+
#include <functional>
32+
#include <iosfwd>
2833
#include <string>
2934
#include <utility>
3035

@@ -42,13 +47,19 @@ struct Device final {
4247
index_(place.GetType() == phi::AllocationType::CPU
4348
? static_cast<DeviceIndex>(-1)
4449
: place.GetDeviceId()),
45-
custom_device_type_(place.GetDeviceType()) {}
50+
custom_device_type_(place.GetDeviceType()) {
51+
validate();
52+
}
4653
Device(DeviceType type, DeviceIndex index = -1)
47-
: type_(type), index_(index) {} // NOLINT
54+
: type_(type), index_(index) { // NOLINT
55+
validate();
56+
}
4857
Device(DeviceType type, DeviceIndex index, std::string custom_device_type)
4958
: type_(type),
5059
index_(index),
51-
custom_device_type_(std::move(custom_device_type)) {} // NOLINT
60+
custom_device_type_(std::move(custom_device_type)) { // NOLINT
61+
validate();
62+
}
5263

5364
/// Constructs a `Device` from a string description, for convenience.
5465
/// The string supplied must follow the following schema:
@@ -63,10 +74,51 @@ struct Device final {
6374

6475
DeviceType type() const noexcept { return type_; }
6576

77+
bool operator!=(const Device& other) const noexcept {
78+
return !(*this == other);
79+
}
80+
81+
void set_index(DeviceIndex index) {
82+
index_ = index;
83+
validate();
84+
}
85+
6686
bool is_cuda() const noexcept { return type_ == DeviceType::CUDA; }
6787

88+
bool is_privateuseone() const noexcept {
89+
return type_ == DeviceType::PrivateUse1;
90+
}
91+
92+
bool is_mps() const noexcept { return false; }
93+
94+
bool is_hip() const noexcept { return false; }
95+
96+
bool is_ve() const noexcept { return false; }
97+
98+
bool is_xpu() const noexcept { return type_ == DeviceType::XPU; }
99+
100+
bool is_ipu() const noexcept { return type_ == DeviceType::IPU; }
101+
102+
bool is_xla() const noexcept { return false; }
103+
104+
bool is_mtia() const noexcept { return false; }
105+
106+
bool is_hpu() const noexcept { return false; }
107+
108+
bool is_lazy() const noexcept { return false; }
109+
110+
bool is_vulkan() const noexcept { return false; }
111+
112+
bool is_metal() const noexcept { return false; }
113+
114+
bool is_maia() const noexcept { return false; }
115+
116+
bool is_meta() const noexcept { return false; }
117+
68118
bool is_cpu() const noexcept { return type_ == DeviceType::CPU; }
69119

120+
bool supports_as_strided() const noexcept { return type_ != DeviceType::IPU; }
121+
70122
std::string str() const;
71123

72124
bool operator==(const Device& other) const noexcept {
@@ -96,12 +148,37 @@ struct Device final {
96148
DeviceType type_{DeviceType::CPU};
97149
DeviceIndex index_{-1};
98150
std::string custom_device_type_;
151+
152+
void validate() {
153+
#ifndef NDEBUG
154+
TORCH_INTERNAL_ASSERT(index_ >= -1,
155+
"Device index must be -1 or non-negative, got ",
156+
static_cast<int>(index_));
157+
TORCH_INTERNAL_ASSERT(!is_cpu() || index_ <= 0,
158+
"CPU device index must be -1 or zero, got ",
159+
static_cast<int>(index_));
160+
#endif
161+
}
99162
};
100163

101164
std::ostream& operator<<(std::ostream& stream, const Device& device);
102165

103166
} // namespace c10
104167

168+
namespace std {
169+
template <>
170+
struct hash<c10::Device> {
171+
size_t operator()(c10::Device d) const noexcept {
172+
static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit");
173+
static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit");
174+
uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type()))
175+
<< 16 |
176+
static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
177+
return std::hash<uint32_t>{}(bits);
178+
}
179+
};
180+
} // namespace std
181+
105182
namespace at {
106183
using c10::Device;
107184
using c10::DeviceIndex;

paddle/phi/api/include/compat/c10/core/DeviceType.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#pragma once
1616

17+
#include <cstdint>
18+
#include <functional>
1719
#include <ostream>
1820

1921
#include "paddle/phi/common/place.h"
@@ -26,13 +28,15 @@ enum class DeviceType : int8_t {
2628
XPU = 12,
2729
IPU = 18,
2830
CUSTOM = 20,
31+
PrivateUse1 = CUSTOM,
2932
};
3033

3134
constexpr DeviceType kCUDA = DeviceType::CUDA;
3235
constexpr DeviceType kCPU = DeviceType::CPU;
3336
constexpr DeviceType kCUSTOM = DeviceType::CUSTOM;
3437
constexpr DeviceType kXPU = DeviceType::XPU;
3538
constexpr DeviceType kIPU = DeviceType::IPU;
39+
constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;
3640

3741
inline phi::AllocationType DeviceTypeToPhi(DeviceType d) {
3842
switch (d) {
@@ -103,12 +107,22 @@ inline std::ostream& operator<<(std::ostream& os, DeviceType d) {
103107

104108
} // namespace c10
105109

110+
namespace std {
111+
template <>
112+
struct hash<c10::DeviceType> {
113+
std::size_t operator()(c10::DeviceType k) const noexcept {
114+
return std::hash<int>()(static_cast<int>(k));
115+
}
116+
};
117+
} // namespace std
118+
106119
namespace at {
107120
using c10::DeviceType;
108121
using c10::kCPU;
109122
using c10::kCUDA;
110123
using c10::kCUSTOM;
111124
using c10::kIPU;
125+
using c10::kPrivateUse1;
112126
using c10::kXPU;
113127
} // namespace at
114128

@@ -118,5 +132,6 @@ using c10::kCPU;
118132
using c10::kCUDA;
119133
using c10::kCUSTOM;
120134
using c10::kIPU;
135+
using c10::kPrivateUse1;
121136
using c10::kXPU;
122137
} // namespace torch

paddle/phi/api/include/compat/c10/util/Exception.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,16 @@
3434
namespace c10 {
3535
#define TORCH_CHECK(COND, ...) PD_CHECK(COND, ##__VA_ARGS__);
3636
#define TORCH_INTERNAL_ASSERT(COND, ...) PD_CHECK(COND, ##__VA_ARGS__);
37-
#define TORCH_CHECK_OP(val1, val2, op) \
38-
do { \
39-
auto&& _val1 = (val1); \
40-
auto&& _val2 = (val2); \
41-
if (!(_val1 op _val2)) { \
42-
std::ostringstream _result; \
43-
_result << "Expected " #val1 " " #op " " #val2 " (" << _val1 << " " \
44-
<< #op << " " << _val2 << "), but got false"; \
45-
PD_THROW(_result.str()); \
46-
} \
37+
#define TORCH_CHECK_OP(val1, val2, op) \
38+
do { \
39+
auto&& _val1 = (val1); \
40+
auto&& _val2 = (val2); \
41+
if (!(_val1 op _val2)) { \
42+
std::ostringstream _result; \
43+
_result << "Check failed: " #val1 " " #op " " #val2 " (" << _val1 \
44+
<< " vs. " << _val2 << "). "; \
45+
PD_THROW(_result.str()); \
46+
} \
4747
} while (false);
4848

4949
// Check for a given boolean condition.

0 commit comments

Comments
 (0)