Skip to content

Commit b446008

Browse files
committed
[FFI] Make Error to be ABI invariant
1 parent 4cf0157 commit b446008

File tree

7 files changed

+42
-36
lines changed

7 files changed

+42
-36
lines changed

ffi/include/tvm/ffi/c_api.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,12 @@ typedef struct {
224224
* \brief The traceback of the error.
225225
*/
226226
TVMFFIByteArray traceback;
227+
/*!
228+
* \brief Function handle to update the traceback of the error.
229+
* \param self The self object handle.
230+
* \param traceback The traceback to update.
231+
*/
232+
void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback);
227233
} TVMFFIErrorCell;
228234

229235
/*!
@@ -483,14 +489,6 @@ TVM_FFI_DLL TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind,
483489
const TVMFFIByteArray* message,
484490
const TVMFFIByteArray* traceback);
485491

486-
/*!
487-
* \brief Update the traceback of an Error object.
488-
* \param obj The error handle.
489-
* \param traceback The traceback to update.
490-
*/
491-
TVM_FFI_DLL void TVMFFIErrorUpdateTraceback(TVMFFIObjectHandle obj,
492-
const TVMFFIByteArray* traceback);
493-
494492
/*!
495493
* \brief Check if there are any signals raised in the surrounding env.
496494
* \return 0 when success, nonzero when failure happens

ffi/include/tvm/ffi/error.h

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,26 +81,39 @@ struct EnvErrorAlreadySet : public std::exception {};
8181
*/
8282
class ErrorObj : public Object, public TVMFFIErrorCell {
8383
public:
84-
/*!
85-
* \brief Update the traceback of the error object.
86-
* \param traceback The traceback to update.
87-
*/
88-
void UpdateTraceback(const TVMFFIByteArray* traceback_str) {
89-
this->traceback_data_ = std::string(traceback_str->data, traceback_str->size);
90-
this->traceback = TVMFFIByteArray{this->traceback_data_.data(), this->traceback_data_.length()};
91-
}
92-
9384
static constexpr const int32_t _type_index = TypeIndex::kTVMFFIError;
9485
static constexpr const char* _type_key = "object.Error";
9586

9687
TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ErrorObj, Object);
88+
};
89+
90+
namespace details {
91+
class ErrorObjFromStd : public ErrorObj {
92+
public:
93+
ErrorObjFromStd(std::string kind, std::string message, std::string traceback)
94+
: kind_data_(kind), message_data_(message), traceback_data_(traceback) {
95+
this->kind = TVMFFIByteArray{kind_data_.data(), kind_data_.length()};
96+
this->message = TVMFFIByteArray{message_data_.data(), message_data_.length()};
97+
this->traceback = TVMFFIByteArray{traceback_data_.data(), traceback_data_.length()};
98+
this->update_traceback = UpdateTraceback;
99+
}
97100

98101
private:
99-
friend class Error;
102+
/*!
103+
* \brief Update the traceback of the error object.
104+
* \param traceback The traceback to update.
105+
*/
106+
static void UpdateTraceback(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback_str) {
107+
ErrorObjFromStd* obj = static_cast<ErrorObjFromStd*>(self);
108+
obj->traceback_data_ = std::string(traceback_str->data, traceback_str->size);
109+
obj->traceback = TVMFFIByteArray{obj->traceback_data_.data(), obj->traceback_data_.length()};
110+
}
111+
100112
std::string kind_data_;
101113
std::string message_data_;
102114
std::string traceback_data_;
103115
};
116+
} // namespace details
104117

105118
/*!
106119
* \brief Managed reference to ErrorObj
@@ -109,14 +122,7 @@ class ErrorObj : public Object, public TVMFFIErrorCell {
109122
class Error : public ObjectRef, public std::exception {
110123
public:
111124
Error(std::string kind, std::string message, std::string traceback) {
112-
ObjectPtr<ErrorObj> n = make_object<ErrorObj>();
113-
n->kind_data_ = std::move(kind);
114-
n->message_data_ = std::move(message);
115-
n->traceback_data_ = std::move(traceback);
116-
n->kind = TVMFFIByteArray{n->kind_data_.data(), n->kind_data_.length()};
117-
n->message = TVMFFIByteArray{n->message_data_.data(), n->message_data_.length()};
118-
n->traceback = TVMFFIByteArray{n->traceback_data_.data(), n->traceback_data_.length()};
119-
data_ = std::move(n);
125+
data_ = make_object<details::ErrorObjFromStd>(kind, message, traceback);
120126
}
121127

122128
Error(std::string kind, std::string message, const TVMFFIByteArray* traceback)
@@ -137,6 +143,11 @@ class Error : public ObjectRef, public std::exception {
137143
return std::string(obj->traceback.data, obj->traceback.size);
138144
}
139145

146+
void UpdateTraceback(const TVMFFIByteArray* traceback_str) {
147+
ErrorObj* obj = static_cast<ErrorObj*>(data_.get());
148+
obj->update_traceback(obj, traceback_str);
149+
}
150+
140151
const char* what() const noexcept(true) override {
141152
thread_local std::string what_data;
142153
ErrorObj* obj = static_cast<ErrorObj*>(data_.get());

ffi/include/tvm/ffi/function.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,9 @@ class Function : public ObjectRef {
411411
return res.value();
412412
}
413413

414-
static Function GetGlobalRequired(const std::string& name) { return GetGlobalRequired(name); }
414+
static Function GetGlobalRequired(const std::string& name) {
415+
return GetGlobalRequired(std::string_view(name.data(), name.length()));
416+
}
415417

416418
static Function GetGlobalRequired(const String& name) {
417419
return GetGlobalRequired(std::string_view(name.data(), name.length()));

ffi/include/tvm/ffi/reflection/reflection.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
#include <tvm/ffi/function.h>
2929
#include <tvm/ffi/type_traits.h>
3030

31+
#include <string>
32+
3133
namespace tvm {
3234
namespace ffi {
3335
namespace details {

ffi/src/ffi/error.cc

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,6 @@ void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result) {
6767
tvm::ffi::SafeCallContext::ThreadLocal()->MoveFromRaised(result);
6868
}
6969

70-
void TVMFFIErrorUpdateTraceback(TVMFFIObjectHandle obj, const TVMFFIByteArray* traceback) {
71-
TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
72-
static_cast<tvm::ffi::ErrorObj*>(reinterpret_cast<tvm::ffi::Object*>(obj))
73-
->UpdateTraceback(traceback);
74-
TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIErrorUpdateTraceback);
75-
}
76-
7770
TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind, const TVMFFIByteArray* message,
7871
const TVMFFIByteArray* traceback) {
7972
TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();

python/tvm/ffi/cython/base.pxi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ cdef extern from "tvm/ffi/c_api.h":
128128
TVMFFIByteArray kind
129129
TVMFFIByteArray message
130130
TVMFFIByteArray traceback
131+
void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback)
131132

132133
ctypedef int (*TVMFFISafeCallType)(
133134
void* ctx, const TVMFFIAny* args, int32_t num_args,
@@ -144,7 +145,6 @@ cdef extern from "tvm/ffi/c_api.h":
144145
int TVMFFIFunctionGetGlobal(TVMFFIByteArray* name, TVMFFIObjectHandle* out) nogil
145146
void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result) nogil
146147
void TVMFFIErrorSetRaised(TVMFFIObjectHandle error) nogil
147-
void TVMFFIErrorUpdateTraceback(TVMFFIObjectHandle error, TVMFFIByteArray* traceback) nogil
148148
TVMFFIObjectHandle TVMFFIErrorCreate(TVMFFIByteArray* kind, TVMFFIByteArray* message,
149149
TVMFFIByteArray* traceback) nogil
150150
int TVMFFIEnvRegisterCAPI(TVMFFIByteArray* name, void* ptr) nogil

python/tvm/ffi/cython/error.pxi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ cdef class Error(Object):
8989
The traceback to update.
9090
"""
9191
cdef ByteArrayArg traceback_arg = ByteArrayArg(c_str(traceback))
92-
TVMFFIErrorUpdateTraceback(self.chandle, traceback_arg.cptr())
92+
TVMFFIErrorGetCellPtr(self.chandle).update_traceback(self.chandle, traceback_arg.cptr())
9393

9494
def py_error(self):
9595
"""

0 commit comments

Comments
 (0)