@@ -43,6 +43,8 @@ struct Resources {
4343 // per-beamline resources; constant for a given beamline
4444 // / beamline elements
4545 Buf<OpticalElement> d_elements;
46+ // / per-element mask selecting the elements for which events are recorded
47+ Buf<bool > d_recordMask;
4648 // / all rays generated from all light sources
4749 std::vector<Ray> h_rays;
4850
@@ -75,7 +77,7 @@ struct Resources {
7577
7678 // / update resources
7779 template <typename Queue>
78- Config update (Queue q, const Group& group, int maxEvents, int maxBatchSize) {
80+ Config update (Queue q, const Group& group, int maxEvents, int maxBatchSize, const std::vector< bool >& recordMask ) {
7981 RAYX_PROFILE_FUNCTION_STDOUT ();
8082
8183 const auto platformHost = alpaka::PlatformCpu{};
@@ -98,6 +100,16 @@ struct Resources {
98100 allocBuf (q, d_elements, numElements);
99101 alpaka::memcpy (q, *d_elements, alpaka::createView (devHost, elements, numElements));
100102
103+ // record mask
104+ assert (recordMask.size () == static_cast <size_t >(numElements));
105+ allocBuf (q, d_recordMask, numElements);
106+ std::unique_ptr<bool []> tmpHostMask (new bool [numElements]);
107+ for (int i = 0 ; i < numElements; ++i) {
108+ tmpHostMask[i] = recordMask[i];
109+ }
110+ auto hostView = alpaka::createView (devHost, tmpHostMask.get (), numElements);
111+ alpaka::memcpy (q, *d_recordMask, hostView);
112+
101113 // input rays
102114 h_rays = group.compileSources (1 ); // TODO: generate rays on device
103115 const auto numRaysTotal = static_cast <int >(h_rays.size ());
@@ -149,8 +161,13 @@ class MegaKernelTracer : public DeviceTracer {
149161 Resources<Acc> m_resources;
150162
151163 public:
152- virtual RaySoA trace (const Group& beamline, const Sequential sequential, const int maxBatchSize, const int maxEvents,
153- std::shared_ptr<bool []> recordMask, const RayAttrFlag attr) override {
164+ virtual RaySoA trace (
165+ const Group& beamline,
166+ Sequential sequential,
167+ const int maxBatchSize,
168+ const int maxEvents,
169+ const std::vector<bool >& recordMask,
170+ const RayAttrFlag attr) override {
154171 RAYX_PROFILE_FUNCTION_STDOUT ();
155172
156173 const auto platformHost = alpaka::PlatformCpu{};
@@ -160,7 +177,7 @@ class MegaKernelTracer : public DeviceTracer {
160177 using Queue = alpaka::Queue<Acc, alpaka::Blocking>;
161178 auto q = Queue (devAcc);
162179
163- const auto conf = m_resources.update (q, beamline, maxEvents, maxBatchSize);
180+ const auto conf = m_resources.update (q, beamline, maxEvents, maxBatchSize, recordMask );
164181 const auto randomSeed = randomDouble ();
165182
166183 RAYX_VERB << " tracing beamline:" ;
@@ -199,9 +216,8 @@ class MegaKernelTracer : public DeviceTracer {
199216 alpaka::memcpy (q, *m_resources.d_rays , raysViewBatch);
200217
201218 // trace current batch
202- traceBatch (devAcc, q, conf.numElements , conf.numRaysTotal , batchSize, batchStartRayIndex, maxEvents, recordMask, randomSeed, sequential);
219+ traceBatch (devAcc, q, conf.numElements , conf.numRaysTotal , batchSize, batchStartRayIndex, maxEvents, randomSeed, sequential);
203220
204- // prefix sum on compactEventCounts to get compactEventOffsets
205221 alpaka::memcpy (q, alpaka::createView (devHost, compactEventCounts, batchSize), *m_resources.d_compactEventCounts , batchSize);
206222 std::exclusive_scan (compactEventCounts.begin (), compactEventCounts.begin () + batchSize, compactEventOffsets.begin (), 0 );
207223 alpaka::memcpy (q, *m_resources.d_compactEventOffsets , alpaka::createView (devHost, compactEventOffsets, batchSize), batchSize);
@@ -259,7 +275,7 @@ class MegaKernelTracer : public DeviceTracer {
259275 private:
260276 template <typename DevAcc, typename Queue>
261277 void traceBatch (DevAcc devAcc, Queue q, int numElements, int numRaysTotal, int batchSize, int batchStartRayIndex, int maxEvents,
262- std::shared_ptr< bool []> recordMask, double randomSeed, Sequential sequential) {
278+ double randomSeed, Sequential sequential) {
263279 RAYX_PROFILE_FUNCTION_STDOUT ();
264280
265281 // inputs
@@ -277,7 +293,7 @@ class MegaKernelTracer : public DeviceTracer {
277293 .numElements = numElements,
278294 .materialIndices = alpaka::getPtrNative (*m_resources.d_materialIndices ),
279295 .materialTables = alpaka::getPtrNative (*m_resources.d_materialTable ),
280- .recordMask = recordMask. get ( ),
296+ .recordMask = alpaka::getPtrNative (*m_resources. d_recordMask ),
281297 .rays = alpaka::getPtrNative (*m_resources.d_rays ),
282298 };
283299
0 commit comments