Skip to content

Commit 935059b

Browse files
authored
Reorganize code and remove code duplication (#10)
2 parents f64215d + faf6d03 commit 935059b

13 files changed

+503
-522
lines changed

dlext/include/DLExt.h

+31-159
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,47 @@
44
#ifndef HOOMD_DLPACK_EXTENSION_H_
55
#define HOOMD_DLPACK_EXTENSION_H_
66

7+
#include <type_traits>
78
#include <vector>
89

9-
#include "SystemView.h"
10+
#include "cxx11utils.h"
1011
#include "dlpack/dlpack.h"
11-
12+
#include "hoomd/GlobalArray.h"
1213

1314
namespace dlext
1415
{
1516

17+
using namespace hoomd;
18+
19+
// { // Aliases
1620

1721
using DLManagedTensorPtr = DLManagedTensor*;
22+
using DLManagedTensorDeleter = void (*)(DLManagedTensorPtr);
1823

19-
using AccessLocation = access_location::Enum;
20-
const auto kOnHost = access_location::host;
21-
#ifdef ENABLE_CUDA
22-
const auto kOnDevice = access_location::device;
23-
#endif
24+
template <typename T>
25+
using ArrayHandleUPtr = std::unique_ptr<ArrayHandle<T>>;
2426

25-
using AccessMode = access_mode::Enum;
26-
const auto kRead = access_mode::read;
27-
const auto kReadWrite = access_mode::readwrite;
28-
const auto kOverwrite = access_mode::overwrite;
27+
// } // Aliases
28+
29+
// { // Constants
2930

3031
constexpr uint8_t kBits = std::is_same<Scalar, float>::value ? 32 : 64;
3132

32-
template <template <typename> class Array, typename T, typename Object>
33-
using PropertyGetter = const Array<T>& (Object::*)() const;
33+
constexpr DLManagedTensor kInvalidDLManagedTensor {
34+
DLTensor {
35+
nullptr, // data
36+
DLDevice { kDLExtDev, -1 }, // device
37+
-1, // ndim
38+
DLDataType { 0, 0, 0 }, // dtype
39+
nullptr, // shape
40+
nullptr, // stride
41+
0 // byte_offset
42+
},
43+
nullptr,
44+
nullptr
45+
};
3446

35-
template <typename T>
36-
using ArrayHandleUPtr = std::unique_ptr<ArrayHandle<T>>;
47+
// } // Constants
3748

3849
template <typename T>
3950
struct DLDataBridge {
@@ -51,19 +62,19 @@ template <typename T>
5162
using DLDataBridgeUPtr = std::unique_ptr<DLDataBridge<T>>;
5263

5364
template <typename T>
54-
void DLDataBridgeDeleter(DLManagedTensorPtr tensor)
65+
void delete_bridge(DLManagedTensorPtr tensor)
5566
{
5667
if (tensor)
5768
delete static_cast<DLDataBridge<T>*>(tensor->manager_ctx);
5869
}
5970

71+
void do_not_delete(DLManagedTensorPtr tensor) { }
72+
6073
template <typename T>
6174
inline void* opaque(T* data) { return static_cast<void*>(data); }
6275

63-
inline DLDevice dldevice(const SystemView& sysview, bool gpu_flag)
64-
{
65-
return DLDevice { gpu_flag ? kDLCUDA : kDLCPU, sysview.get_device_id(gpu_flag) };
66-
}
76+
template <typename T>
77+
inline void* opaque(const T* data) { return (void*)(data); }
6778

6879
template <typename>
6980
constexpr DLDataType dtype();
@@ -78,19 +89,6 @@ constexpr DLDataType dtype<int3>() { return DLDataType {kDLInt, 32, 1}; }
7889
template <>
7990
constexpr DLDataType dtype<unsigned int>() { return DLDataType {kDLUInt, 32, 1}; }
8091

81-
template <template <typename> class>
82-
unsigned int particle_number(const SystemView& sysview);
83-
template <>
84-
inline unsigned int particle_number<GlobalArray>(const SystemView& sysview)
85-
{
86-
return sysview.local_particle_number();
87-
}
88-
template <>
89-
inline unsigned int particle_number<GlobalVector>(const SystemView& sysview)
90-
{
91-
return sysview.global_particle_number();
92-
}
93-
9492
template <typename>
9593
constexpr int64_t stride1();
9694
template <>
@@ -104,132 +102,6 @@ constexpr int64_t stride1<int3>() { return 3; }
104102
template <>
105103
constexpr int64_t stride1<unsigned int>() { return 1; }
106104

107-
template <template <typename> class A, typename T, typename O>
108-
DLManagedTensorPtr wrap(
109-
const SystemView& sysview, PropertyGetter<A, T, O> getter,
110-
AccessLocation requested_location, AccessMode mode,
111-
int64_t size2 = 1, uint64_t offset = 0, uint64_t stride1_offset = 0
112-
) {
113-
assert((size2 >= 1)); // assert is a macro so the extra parentheses are requiered here
114-
115-
auto location = sysview.is_gpu_enabled() ? requested_location : kOnHost;
116-
auto handle = ArrayHandleUPtr<T>(
117-
new ArrayHandle<T>(INVOKE(*(sysview.particle_data()), getter)(), location, mode)
118-
);
119-
auto bridge = DLDataBridgeUPtr<T>(new DLDataBridge<T>(handle));
120-
121-
#ifdef ENABLE_CUDA
122-
auto gpu_flag = (location == kOnDevice);
123-
#else
124-
auto gpu_flag = false;
125-
#endif
126-
127-
bridge->tensor.manager_ctx = bridge.get();
128-
bridge->tensor.deleter = DLDataBridgeDeleter<T>;
129-
130-
auto& dltensor = bridge->tensor.dl_tensor;
131-
dltensor.data = opaque(bridge->handle->data);
132-
dltensor.device = dldevice(sysview, gpu_flag);
133-
dltensor.dtype = dtype<T>();
134-
135-
auto& shape = bridge->shape;
136-
shape.push_back(particle_number<A>(sysview));
137-
if (size2 > 1)
138-
shape.push_back(size2);
139-
140-
auto& strides = bridge->strides;
141-
strides.push_back(stride1<T>() + stride1_offset);
142-
if (size2 > 1)
143-
strides.push_back(1);
144-
145-
dltensor.ndim = shape.size();
146-
dltensor.shape = reinterpret_cast<std::int64_t*>(shape.data());
147-
dltensor.strides = reinterpret_cast<std::int64_t*>(strides.data());
148-
dltensor.byte_offset = offset;
149-
150-
return &(bridge.release()->tensor);
151-
}
152-
153-
inline DLManagedTensorPtr positions_types(
154-
const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
155-
) {
156-
return wrap(sysview, &ParticleData::getPositions, location, mode, 4);
157-
}
158-
159-
inline DLManagedTensorPtr velocities_masses(
160-
const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
161-
) {
162-
return wrap(sysview, &ParticleData::getVelocities, location, mode, 4);
163-
}
164-
165-
inline DLManagedTensorPtr orientations(
166-
const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
167-
) {
168-
return wrap(sysview, &ParticleData::getOrientationArray, location, mode, 4);
169-
}
170-
171-
inline DLManagedTensorPtr angular_momenta(
172-
const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
173-
) {
174-
return wrap(sysview, &ParticleData::getAngularMomentumArray, location, mode, 4);
175-
}
176-
177-
inline DLManagedTensorPtr moments_of_intertia(
178-
const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
179-
) {
180-
return wrap(sysview, &ParticleData::getMomentsOfInertiaArray, location, mode, 3);
181-
}
182-
183-
inline DLManagedTensorPtr charges(
184-
const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
185-
) {
186-
return wrap(sysview, &ParticleData::getCharges, location, mode);
187-
}
188-
189-
inline DLManagedTensorPtr diameters(
190-
const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
191-
) {
192-
return wrap(sysview, &ParticleData::getDiameters, location, mode);
193-
}
194-
195-
inline DLManagedTensorPtr images(
196-
const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
197-
) {
198-
return wrap(sysview, &ParticleData::getImages, location, mode, 3);
199-
}
200-
201-
inline DLManagedTensorPtr tags(
202-
const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
203-
) {
204-
return wrap(sysview, &ParticleData::getTags, location, mode);
205-
}
206-
207-
inline DLManagedTensorPtr rtags(
208-
const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
209-
) {
210-
return wrap(sysview, &ParticleData::getRTags, location, mode);
211-
}
212-
213-
inline DLManagedTensorPtr net_forces(
214-
const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
215-
) {
216-
return wrap(sysview, &ParticleData::getNetForce, location, mode, 4);
217-
}
218-
219-
inline DLManagedTensorPtr net_torques(
220-
const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
221-
) {
222-
return wrap(sysview, &ParticleData::getNetTorqueArray, location, mode, 4);
223-
}
224-
225-
inline DLManagedTensorPtr net_virial(
226-
const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
227-
) {
228-
return wrap(sysview, &ParticleData::getNetVirial, location, mode, 6);
229-
}
230-
231-
232105
} // namespace dlext
233106

234-
235107
#endif // HOOMD_DLPACK_EXTENSION_H_

dlext/include/Sampler.h

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// SPDX-License-Identifier: MIT
2+
// This file is part of `hoomd-dlext`, see LICENSE.md
3+
4+
#ifndef DLEXT_SAMPLER_H_
5+
#define DLEXT_SAMPLER_H_
6+
7+
#include "SystemView.h"
8+
#include "hoomd/HalfStepHook.h"
9+
10+
namespace dlext
11+
{
12+
13+
using TimeStep = unsigned int;
14+
15+
template <typename ExternalUpdater, template <typename> class Wrapper>
16+
class DEFAULT_VISIBILITY Sampler : public HalfStepHook {
17+
public:
18+
//! Constructor
19+
Sampler(
20+
SystemView sysview,
21+
ExternalUpdater update_callback,
22+
AccessLocation location,
23+
AccessMode mode
24+
);
25+
void setSystemDefinition(SystemDefinitionSPtr sysdef) override
26+
{
27+
_sysview = SystemView(sysdef);
28+
}
29+
void update(TimeStep timestep) override
30+
{
31+
forward_data(_update_callback, _location, _mode, timestep);
32+
}
33+
34+
const SystemView& system_view() const;
35+
36+
//! Wraps the system positions, velocities, reverse tags, images and forces as
37+
//! DLPack tensors and passes them to the external function `callback`.
38+
//!
39+
//! The (non-typed) signature of `callback` is expected to be
40+
//! callback(positions, velocities, rtags, images, forces, n)
41+
//! where `n` ìs an additional `TimeStep` parameter.
42+
//!
43+
//! The data for the particles information is requested at the given `location`
44+
//! and access `mode`. NOTE: Forces are always passed in readwrite mode.
45+
template <typename Callback>
46+
void forward_data(Callback callback, AccessLocation location, AccessMode mode, TimeStep n)
47+
{
48+
auto pos_capsule = Wrapper<PositionsTypes>::wrap(_sysview, location, mode);
49+
auto vel_capsule = Wrapper<VelocitiesMasses>::wrap(_sysview, location, mode);
50+
auto rtags_capsule = Wrapper<RTags>::wrap(_sysview, location, mode);
51+
auto img_capsule = Wrapper<Images>::wrap(_sysview, location, mode);
52+
auto force_capsule = Wrapper<NetForces>::wrap(_sysview, location, kReadWrite);
53+
54+
callback(pos_capsule, vel_capsule, rtags_capsule, img_capsule, force_capsule, n);
55+
}
56+
57+
private:
58+
SystemView _sysview;
59+
ExternalUpdater _update_callback;
60+
AccessLocation _location;
61+
AccessMode _mode;
62+
};
63+
64+
template <typename ExternalUpdater, template <typename> class Wrapper>
65+
Sampler<ExternalUpdater, Wrapper>::Sampler(
66+
SystemView sysview, ExternalUpdater update, AccessLocation location, AccessMode mode
67+
)
68+
: _sysview { sysview }
69+
, _update_callback { update }
70+
, _location { location }
71+
, _mode { mode }
72+
{ }
73+
74+
template <typename ExternalUpdater, template <typename> class Wrapper>
75+
const SystemView& Sampler<ExternalUpdater, Wrapper>::system_view() const
76+
{
77+
return _sysview;
78+
}
79+
80+
} // namespace dlext
81+
82+
#endif // DLEXT_SAMPLER_H_

0 commit comments

Comments
 (0)