Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ third_party/
*~
bazel-*
.humanize
.codex

build_*
# clion workspace.
Expand Down
110 changes: 85 additions & 25 deletions paddle/phi/api/include/compat/c10/core/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@

#include <c10/core/Device.h>
#include <c10/util/Exception.h>

#include <algorithm>
#include <array>
#include <cctype>
#include <exception>
#include <string>

#include "paddle/common/enforce.h"

namespace c10 {
Expand All @@ -45,44 +51,98 @@ const char* DeviceTypeToString(DeviceType type) {

DeviceType parse_type(const std::string& device_string) {
static const std::array<std::pair<const char*, DeviceType>,
static_cast<size_t>(4)>
static_cast<size_t>(5)>
types = {{
{"cpu", DeviceType::CPU},
{"cuda", DeviceType::CUDA},
{"ipu", DeviceType::IPU},
{"xpu", DeviceType::XPU},
{"privateuseone", DeviceType::PrivateUse1},
}};
for (const auto& type_pair : types) {
if (device_string == type_pair.first) {
return type_pair.second;
}
auto device = std::find_if(
types.begin(),
types.end(),
[&device_string](const std::pair<const char*, DeviceType>& p) {
return p.first && p.first == device_string;
});
if (device != types.end()) {
return device->second;
}
PADDLE_THROW(::common::errors::InvalidArgument(
"Unknown device type: '%s'. Supported device types are ",
"'cpu', 'cuda', 'ipu', and 'xpu'.",
device_string));
TORCH_CHECK(false,
"Expected one of cpu, cuda, ipu, xpu, privateuseone device type "
"at start of device string: ",
device_string);
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parse_type() can reach the end of a non-void function without returning a DeviceType (after the TORCH_CHECK(false, ...)). Even if TORCH_CHECK throws at runtime, this is still UB in the language rules and may trigger -Wreturn-type/-Werror build failures. Add an explicit return (or an unreachable annotation) after the check to satisfy the compiler.

Suggested change
device_string);
device_string);
return DeviceType::CPU; // Unreachable, added to satisfy compiler.

Copilot uses AI. Check for mistakes.
}

enum DeviceStringParsingState { kStart, kIndexStart, kIndexRest, kError };

Device::Device(const std::string& device_string) : Device(Type::CPU) {
TORCH_CHECK(!device_string.empty(), "Device string must not be empty");
auto colon_pos = device_string.find(':');
std::string type_str = colon_pos == std::string::npos
? device_string
: device_string.substr(0, colon_pos);
type_ = parse_type(type_str);
index_ = -1;
if (colon_pos != std::string::npos) {
std::string index_str = device_string.substr(colon_pos + 1);
try {
index_ = static_cast<DeviceIndex>(std::stoi(index_str));
} catch (const std::invalid_argument&) {
PADDLE_THROW(::common::errors::InvalidArgument(
"Invalid device index: '%s' is not a number.", index_str));
} catch (const std::out_of_range&) {
PADDLE_THROW(::common::errors::InvalidArgument(
"Invalid device index: '%s' is out of range.", index_str));

std::string device_name, device_index_str;
DeviceStringParsingState pstate = DeviceStringParsingState::kStart;

for (size_t i = 0;
pstate != DeviceStringParsingState::kError && i < device_string.size();
++i) {
const char ch = device_string.at(i);
const unsigned char uch = static_cast<unsigned char>(ch);
switch (pstate) {
case DeviceStringParsingState::kStart:
if (ch != ':') {
if (std::isalpha(uch) || ch == '_') {
device_name.push_back(ch);
} else {
pstate = DeviceStringParsingState::kError;
}
} else {
pstate = DeviceStringParsingState::kIndexStart;
}
break;
case DeviceStringParsingState::kIndexStart:
if (std::isdigit(uch)) {
device_index_str.push_back(ch);
pstate = DeviceStringParsingState::kIndexRest;
} else {
pstate = DeviceStringParsingState::kError;
}
break;
case DeviceStringParsingState::kIndexRest:
if (device_index_str.at(0) == '0') {
pstate = DeviceStringParsingState::kError;
break;
}
if (std::isdigit(uch)) {
device_index_str.push_back(ch);
} else {
pstate = DeviceStringParsingState::kError;
}
break;
case DeviceStringParsingState::kError:
break;
}
}

const bool has_error = device_name.empty() ||
pstate == DeviceStringParsingState::kError ||
(pstate == DeviceStringParsingState::kIndexStart &&
device_index_str.empty());
TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'");

try {
if (!device_index_str.empty()) {
index_ = static_cast<DeviceIndex>(std::stoi(device_index_str));
}
} catch (const std::exception&) {
TORCH_CHECK(false,
"Could not parse device index '",
device_index_str,
"' in device string '",
device_string,
"'");
}
type_ = parse_type(device_name);
validate();
}

std::string Device::str() const {
Expand Down
83 changes: 80 additions & 3 deletions paddle/phi/api/include/compat/c10/core/Device.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ using gpuStream_t = hipStream_t;
#endif

#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iosfwd>
#include <string>
#include <utility>

Expand All @@ -42,13 +47,19 @@ struct Device final {
index_(place.GetType() == phi::AllocationType::CPU
? static_cast<DeviceIndex>(-1)
: place.GetDeviceId()),
custom_device_type_(place.GetDeviceType()) {}
custom_device_type_(place.GetDeviceType()) {
validate();
}
Device(DeviceType type, DeviceIndex index = -1)
: type_(type), index_(index) {} // NOLINT
: type_(type), index_(index) { // NOLINT
validate();
}
Device(DeviceType type, DeviceIndex index, std::string custom_device_type)
: type_(type),
index_(index),
custom_device_type_(std::move(custom_device_type)) {} // NOLINT
custom_device_type_(std::move(custom_device_type)) { // NOLINT
validate();
}

/// Constructs a `Device` from a string description, for convenience.
/// The string supplied must follow the following schema:
Expand All @@ -63,10 +74,51 @@ struct Device final {

DeviceType type() const noexcept { return type_; }

bool operator!=(const Device& other) const noexcept {
return !(*this == other);
}

void set_index(DeviceIndex index) {
index_ = index;
validate();
}

bool is_cuda() const noexcept { return type_ == DeviceType::CUDA; }

bool is_privateuseone() const noexcept {
return type_ == DeviceType::PrivateUse1;
}

bool is_mps() const noexcept { return false; }

bool is_hip() const noexcept { return false; }

bool is_ve() const noexcept { return false; }

bool is_xpu() const noexcept { return type_ == DeviceType::XPU; }

bool is_ipu() const noexcept { return type_ == DeviceType::IPU; }

bool is_xla() const noexcept { return false; }

bool is_mtia() const noexcept { return false; }

bool is_hpu() const noexcept { return false; }

bool is_lazy() const noexcept { return false; }

bool is_vulkan() const noexcept { return false; }

bool is_metal() const noexcept { return false; }

bool is_maia() const noexcept { return false; }

bool is_meta() const noexcept { return false; }

bool is_cpu() const noexcept { return type_ == DeviceType::CPU; }

bool supports_as_strided() const noexcept { return type_ != DeviceType::IPU; }

std::string str() const;

bool operator==(const Device& other) const noexcept {
Expand Down Expand Up @@ -96,12 +148,37 @@ struct Device final {
DeviceType type_{DeviceType::CPU};
DeviceIndex index_{-1};
std::string custom_device_type_;

void validate() {
#ifndef NDEBUG
TORCH_INTERNAL_ASSERT(index_ >= -1,
"Device index must be -1 or non-negative, got ",
static_cast<int>(index_));
TORCH_INTERNAL_ASSERT(!is_cpu() || index_ <= 0,
"CPU device index must be -1 or zero, got ",
static_cast<int>(index_));
#endif
}
};

std::ostream& operator<<(std::ostream& stream, const Device& device);

} // namespace c10

namespace std {
template <>
struct hash<c10::Device> {
size_t operator()(c10::Device d) const noexcept {
static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit");
static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit");
uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type()))
<< 16 |
static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
return std::hash<uint32_t>{}(bits);
}
Comment on lines +168 to +178
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::hash<c10::Device>::operator() takes the key by value, which will copy c10::Device (including its std::string custom_device_type_) every time the hash is computed. This is avoidable overhead for unordered_map/unordered_set usage. Take the parameter as const c10::Device& instead, and consider incorporating custom_device_type_ into the hash to reduce collisions since operator== includes it.

Copilot uses AI. Check for mistakes.
};
} // namespace std

namespace at {
using c10::Device;
using c10::DeviceIndex;
Expand Down
15 changes: 15 additions & 0 deletions paddle/phi/api/include/compat/c10/core/DeviceType.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#pragma once

#include <cstdint>
#include <functional>
#include <ostream>

#include "paddle/phi/common/place.h"
Expand All @@ -26,13 +28,15 @@ enum class DeviceType : int8_t {
XPU = 12,
IPU = 18,
CUSTOM = 20,
PrivateUse1 = CUSTOM,
};

constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUSTOM = DeviceType::CUSTOM;
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;

inline phi::AllocationType DeviceTypeToPhi(DeviceType d) {
switch (d) {
Expand Down Expand Up @@ -103,12 +107,22 @@ inline std::ostream& operator<<(std::ostream& os, DeviceType d) {

} // namespace c10

namespace std {
template <>
struct hash<c10::DeviceType> {
std::size_t operator()(c10::DeviceType k) const noexcept {
return std::hash<int>()(static_cast<int>(k));
}
};
} // namespace std

namespace at {
using c10::DeviceType;
using c10::kCPU;
using c10::kCUDA;
using c10::kCUSTOM;
using c10::kIPU;
using c10::kPrivateUse1;
using c10::kXPU;
} // namespace at

Expand All @@ -118,5 +132,6 @@ using c10::kCPU;
using c10::kCUDA;
using c10::kCUSTOM;
using c10::kIPU;
using c10::kPrivateUse1;
using c10::kXPU;
} // namespace torch
20 changes: 10 additions & 10 deletions paddle/phi/api/include/compat/c10/util/Exception.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@
namespace c10 {
#define TORCH_CHECK(COND, ...) PD_CHECK(COND, ##__VA_ARGS__);
#define TORCH_INTERNAL_ASSERT(COND, ...) PD_CHECK(COND, ##__VA_ARGS__);
#define TORCH_CHECK_OP(val1, val2, op) \
do { \
auto&& _val1 = (val1); \
auto&& _val2 = (val2); \
if (!(_val1 op _val2)) { \
std::ostringstream _result; \
_result << "Expected " #val1 " " #op " " #val2 " (" << _val1 << " " \
<< #op << " " << _val2 << "), but got false"; \
PD_THROW(_result.str()); \
} \
#define TORCH_CHECK_OP(val1, val2, op) \
do { \
auto&& _val1 = (val1); \
auto&& _val2 = (val2); \
if (!(_val1 op _val2)) { \
std::ostringstream _result; \
_result << "Check failed: " #val1 " " #op " " #val2 " (" << _val1 \
<< " vs. " << _val2 << "). "; \
PD_THROW(_result.str()); \
} \
} while (false);

// Check for a given boolean condition.
Expand Down
Loading
Loading