4
4
#ifndef HOOMD_DLPACK_EXTENSION_H_
5
5
#define HOOMD_DLPACK_EXTENSION_H_
6
6
7
+ #include < type_traits>
7
8
#include < vector>
8
9
9
- #include " SystemView .h"
10
+ #include " cxx11utils .h"
10
11
#include " dlpack/dlpack.h"
11
-
12
+ # include " hoomd/GlobalArray.h "
12
13
13
14
namespace dlext
14
15
{
15
16
17
+ using namespace hoomd ;
18
+
19
+ // { // Aliases
16
20
17
21
using DLManagedTensorPtr = DLManagedTensor*;
22
+ using DLManagedTensorDeleter = void (*)(DLManagedTensorPtr);
18
23
19
- using AccessLocation = access_location::Enum;
20
- const auto kOnHost = access_location::host;
21
- #ifdef ENABLE_CUDA
22
- const auto kOnDevice = access_location::device;
23
- #endif
24
+ template <typename T>
25
+ using ArrayHandleUPtr = std::unique_ptr<ArrayHandle<T>>;
24
26
25
- using AccessMode = access_mode::Enum;
26
- const auto kRead = access_mode::read;
27
- const auto kReadWrite = access_mode::readwrite;
28
- const auto kOverwrite = access_mode::overwrite;
27
+ // } // Aliases
28
+
29
+ // { // Constants
29
30
30
31
constexpr uint8_t kBits = std::is_same<Scalar, float >::value ? 32 : 64 ;
31
32
32
- template <template <typename > class Array , typename T, typename Object>
33
- using PropertyGetter = const Array<T>& (Object::*)() const ;
33
+ constexpr DLManagedTensor kInvalidDLManagedTensor {
34
+ DLTensor {
35
+ nullptr , // data
36
+ DLDevice { kDLExtDev , -1 }, // device
37
+ -1 , // ndim
38
+ DLDataType { 0 , 0 , 0 }, // dtype
39
+ nullptr , // shape
40
+ nullptr , // stride
41
+ 0 // byte_offset
42
+ },
43
+ nullptr ,
44
+ nullptr
45
+ };
34
46
35
- template <typename T>
36
- using ArrayHandleUPtr = std::unique_ptr<ArrayHandle<T>>;
47
+ // } // Constants
37
48
38
49
template <typename T>
39
50
struct DLDataBridge {
@@ -51,19 +62,19 @@ template <typename T>
51
62
using DLDataBridgeUPtr = std::unique_ptr<DLDataBridge<T>>;
52
63
53
64
template <typename T>
54
- void DLDataBridgeDeleter (DLManagedTensorPtr tensor)
65
+ void delete_bridge (DLManagedTensorPtr tensor)
55
66
{
56
67
if (tensor)
57
68
delete static_cast <DLDataBridge<T>*>(tensor->manager_ctx );
58
69
}
59
70
71
+ void do_not_delete (DLManagedTensorPtr tensor) { }
72
+
60
73
template <typename T>
61
74
inline void * opaque (T* data) { return static_cast <void *>(data); }
62
75
63
- inline DLDevice dldevice (const SystemView& sysview, bool gpu_flag)
64
- {
65
- return DLDevice { gpu_flag ? kDLCUDA : kDLCPU , sysview.get_device_id (gpu_flag) };
66
- }
76
+ template <typename T>
77
+ inline void * opaque (const T* data) { return (void *)(data); }
67
78
68
79
template <typename >
69
80
constexpr DLDataType dtype ();
@@ -78,19 +89,6 @@ constexpr DLDataType dtype<int3>() { return DLDataType {kDLInt, 32, 1}; }
78
89
template <>
79
90
constexpr DLDataType dtype<unsigned int >() { return DLDataType {kDLUInt , 32 , 1 }; }
80
91
81
- template <template <typename > class >
82
- unsigned int particle_number (const SystemView& sysview);
83
- template <>
84
- inline unsigned int particle_number<GlobalArray>(const SystemView& sysview)
85
- {
86
- return sysview.local_particle_number ();
87
- }
88
- template <>
89
- inline unsigned int particle_number<GlobalVector>(const SystemView& sysview)
90
- {
91
- return sysview.global_particle_number ();
92
- }
93
-
94
92
template <typename >
95
93
constexpr int64_t stride1 ();
96
94
template <>
@@ -104,132 +102,6 @@ constexpr int64_t stride1<int3>() { return 3; }
104
102
template <>
105
103
constexpr int64_t stride1<unsigned int >() { return 1 ; }
106
104
107
- template <template <typename > class A , typename T, typename O>
108
- DLManagedTensorPtr wrap (
109
- const SystemView& sysview, PropertyGetter<A, T, O> getter,
110
- AccessLocation requested_location, AccessMode mode,
111
- int64_t size2 = 1 , uint64_t offset = 0 , uint64_t stride1_offset = 0
112
- ) {
113
- assert ((size2 >= 1 )); // assert is a macro so the extra parentheses are requiered here
114
-
115
- auto location = sysview.is_gpu_enabled () ? requested_location : kOnHost ;
116
- auto handle = ArrayHandleUPtr<T>(
117
- new ArrayHandle<T>(INVOKE (*(sysview.particle_data ()), getter)(), location, mode)
118
- );
119
- auto bridge = DLDataBridgeUPtr<T>(new DLDataBridge<T>(handle));
120
-
121
- #ifdef ENABLE_CUDA
122
- auto gpu_flag = (location == kOnDevice );
123
- #else
124
- auto gpu_flag = false ;
125
- #endif
126
-
127
- bridge->tensor .manager_ctx = bridge.get ();
128
- bridge->tensor .deleter = DLDataBridgeDeleter<T>;
129
-
130
- auto & dltensor = bridge->tensor .dl_tensor ;
131
- dltensor.data = opaque (bridge->handle ->data );
132
- dltensor.device = dldevice (sysview, gpu_flag);
133
- dltensor.dtype = dtype<T>();
134
-
135
- auto & shape = bridge->shape ;
136
- shape.push_back (particle_number<A>(sysview));
137
- if (size2 > 1 )
138
- shape.push_back (size2);
139
-
140
- auto & strides = bridge->strides ;
141
- strides.push_back (stride1<T>() + stride1_offset);
142
- if (size2 > 1 )
143
- strides.push_back (1 );
144
-
145
- dltensor.ndim = shape.size ();
146
- dltensor.shape = reinterpret_cast <std::int64_t *>(shape.data ());
147
- dltensor.strides = reinterpret_cast <std::int64_t *>(strides.data ());
148
- dltensor.byte_offset = offset;
149
-
150
- return &(bridge.release ()->tensor );
151
- }
152
-
153
- inline DLManagedTensorPtr positions_types (
154
- const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
155
- ) {
156
- return wrap (sysview, &ParticleData::getPositions, location, mode, 4 );
157
- }
158
-
159
- inline DLManagedTensorPtr velocities_masses (
160
- const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
161
- ) {
162
- return wrap (sysview, &ParticleData::getVelocities, location, mode, 4 );
163
- }
164
-
165
- inline DLManagedTensorPtr orientations (
166
- const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
167
- ) {
168
- return wrap (sysview, &ParticleData::getOrientationArray, location, mode, 4 );
169
- }
170
-
171
- inline DLManagedTensorPtr angular_momenta (
172
- const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
173
- ) {
174
- return wrap (sysview, &ParticleData::getAngularMomentumArray, location, mode, 4 );
175
- }
176
-
177
- inline DLManagedTensorPtr moments_of_intertia (
178
- const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
179
- ) {
180
- return wrap (sysview, &ParticleData::getMomentsOfInertiaArray, location, mode, 3 );
181
- }
182
-
183
- inline DLManagedTensorPtr charges (
184
- const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
185
- ) {
186
- return wrap (sysview, &ParticleData::getCharges, location, mode);
187
- }
188
-
189
- inline DLManagedTensorPtr diameters (
190
- const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
191
- ) {
192
- return wrap (sysview, &ParticleData::getDiameters, location, mode);
193
- }
194
-
195
- inline DLManagedTensorPtr images (
196
- const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
197
- ) {
198
- return wrap (sysview, &ParticleData::getImages, location, mode, 3 );
199
- }
200
-
201
- inline DLManagedTensorPtr tags (
202
- const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
203
- ) {
204
- return wrap (sysview, &ParticleData::getTags, location, mode);
205
- }
206
-
207
- inline DLManagedTensorPtr rtags (
208
- const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
209
- ) {
210
- return wrap (sysview, &ParticleData::getRTags, location, mode);
211
- }
212
-
213
- inline DLManagedTensorPtr net_forces (
214
- const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
215
- ) {
216
- return wrap (sysview, &ParticleData::getNetForce, location, mode, 4 );
217
- }
218
-
219
- inline DLManagedTensorPtr net_torques (
220
- const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
221
- ) {
222
- return wrap (sysview, &ParticleData::getNetTorqueArray, location, mode, 4 );
223
- }
224
-
225
- inline DLManagedTensorPtr net_virial (
226
- const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
227
- ) {
228
- return wrap (sysview, &ParticleData::getNetVirial, location, mode, 6 );
229
- }
230
-
231
-
232
105
} // namespace dlext
233
106
234
-
235
107
#endif // HOOMD_DLPACK_EXTENSION_H_
0 commit comments