Skip to content

Commit d9170fa

Browse files
Atraxusichinii
authored andcommitted
feat(cli): add support for recording events of specific elements only
1 parent c646ed7 commit d9170fa

File tree

11 files changed

+36
-24
lines changed

11 files changed

+36
-24
lines changed

Intern/rayx-core/src/Shader/DynamicElements.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ void dynamicElements(const int gid, const InvState& inv, OutputEvents& outputEve
6060
}
6161

6262
// write ray in local element coordinates to global memory
63-
if (inv.recordElementIndex < 0 || col.elementIndex == inv.recordElementIndex) {
63+
if (!inv.recordMask || inv.recordMask[col.elementIndex]) {
6464
outputEvents.events[gid * inv.maxEvents + numRecorded] = ray;
6565
++numRecorded;
6666
}

Intern/rayx-core/src/Shader/InvocationState.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ struct RAYX_API InvState {
1717
int batchSize;
1818
int batchStartRayIndex;
1919
int maxEvents;
20-
int recordElementIndex; //< Index of element, for which to record events. Others are discarded. -1 to record events of all elements.
2120
double randomSeed;
2221
Sequential sequential = Sequential::No;
2322

2423
OpticalElement* elements;
2524
int numElements;
2625
int* materialIndices;
2726
double* materialTables;
27+
bool* recordMask; //< Mask that decides which elements to record events for (array length is numElements)
2828
Ray* rays;
2929
};
3030

Intern/rayx-core/src/Tracer/DeviceTracer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class RAYX_API DeviceTracer {
1717
public:
1818
virtual ~DeviceTracer() = default;
1919

20-
virtual RaySoA trace(const Group&, Sequential sequential, const int maxBatchSize, const int maxEvents, const int recordElementIndex,
20+
virtual RaySoA trace(const Group&, Sequential sequential, const int maxBatchSize, const int maxEvents, std::shared_ptr<bool[]> recordMask,
2121
const RayAttrFlag attr) = 0;
2222
};
2323

Intern/rayx-core/src/Tracer/MegaKernelTracer.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ class MegaKernelTracer : public DeviceTracer {
150150

151151
public:
152152
virtual RaySoA trace(const Group& beamline, const Sequential sequential, const int maxBatchSize, const int maxEvents,
153-
const int recordElementIndex, const RayAttrFlag attr) override {
153+
std::shared_ptr<bool[]> recordMask, const RayAttrFlag attr) override {
154154
RAYX_PROFILE_FUNCTION_STDOUT();
155155

156156
const auto platformHost = alpaka::PlatformCpu{};
@@ -199,8 +199,7 @@ class MegaKernelTracer : public DeviceTracer {
199199
alpaka::memcpy(q, *m_resources.d_rays, raysViewBatch);
200200

201201
// trace current batch
202-
traceBatch(devAcc, q, conf.numElements, conf.numRaysTotal, batchSize, batchStartRayIndex, maxEvents, recordElementIndex, randomSeed,
203-
sequential);
202+
traceBatch(devAcc, q, conf.numElements, conf.numRaysTotal, batchSize, batchStartRayIndex, maxEvents, recordMask, randomSeed, sequential);
204203

205204
// prefix sum on compactEventCounts to get compactEventOffsets
206205
alpaka::memcpy(q, alpaka::createView(devHost, compactEventCounts, batchSize), *m_resources.d_compactEventCounts, batchSize);
@@ -260,7 +259,7 @@ class MegaKernelTracer : public DeviceTracer {
260259
private:
261260
template <typename DevAcc, typename Queue>
262261
void traceBatch(DevAcc devAcc, Queue q, int numElements, int numRaysTotal, int batchSize, int batchStartRayIndex, int maxEvents,
263-
int recordElementIndex, double randomSeed, Sequential sequential) {
262+
std::shared_ptr<bool[]> recordMask, double randomSeed, Sequential sequential) {
264263
RAYX_PROFILE_FUNCTION_STDOUT();
265264

266265
// inputs
@@ -270,7 +269,6 @@ class MegaKernelTracer : public DeviceTracer {
270269
.batchSize = batchSize,
271270
.batchStartRayIndex = batchStartRayIndex,
272271
.maxEvents = maxEvents,
273-
.recordElementIndex = recordElementIndex,
274272
.randomSeed = randomSeed,
275273
.sequential = sequential,
276274

@@ -279,6 +277,7 @@ class MegaKernelTracer : public DeviceTracer {
279277
.numElements = numElements,
280278
.materialIndices = alpaka::getPtrNative(*m_resources.d_materialIndices),
281279
.materialTables = alpaka::getPtrNative(*m_resources.d_materialTable),
280+
.recordMask = recordMask.get(),
282281
.rays = alpaka::getPtrNative(*m_resources.d_rays),
283282
};
284283

Intern/rayx-core/src/Tracer/Tracer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ Tracer::Tracer(const DeviceConfig& deviceConfig) {
5151
}
5252
}
5353

54-
RaySoA Tracer::trace(const Group& group, Sequential sequential, uint64_t maxBatchSize, uint32_t maxEvents, int32_t recordElementIndex,
54+
RaySoA Tracer::trace(const Group& group, Sequential sequential, uint64_t maxBatchSize, uint32_t maxEvents, std::shared_ptr<bool[]> recordMask,
5555
RayAttrFlag attr) {
5656
// in sequential tracing, maxEvents should be equal to the number of elements
5757
if (sequential == Sequential::Yes) maxEvents = group.numElements();
5858

59-
return m_deviceTracer->trace(group, sequential, static_cast<int>(maxBatchSize), static_cast<int>(maxEvents), recordElementIndex, attr);
59+
return m_deviceTracer->trace(group, sequential, static_cast<int>(maxBatchSize), static_cast<int>(maxEvents), recordMask, attr);
6060
}
6161

6262
int Tracer::defaultMaxEvents(const Group* group) {

Intern/rayx-core/src/Tracer/Tracer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class RAYX_API Tracer {
2929
// This will call the trace implementation of a subclass
3030
// See `BundleHistory` for information about the return value.
3131
// `max_batch_size` corresponds to the maximal number of rays that will be put into `traceRaw` in one batch.
32-
RaySoA trace(const Group& group, Sequential sequential, uint64_t maxBatchSize, uint32_t maxEvents, int32_t recordElementIndex,
32+
RaySoA trace(const Group& group, Sequential sequential, uint64_t maxBatchSize, uint32_t maxEvents, std::shared_ptr<bool[]> recordMask = nullptr,
3333
RayAttrFlag attr = RayAttrFlag::All);
3434

3535
static int defaultMaxEvents(const Group* group);

Intern/rayx-core/tests/setupTests.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ void writeToOutputCSV(const RAYX::BundleHistory& hist, std::string filename) {
7777

7878
RAYX::BundleHistory traceRML(std::string filename) {
7979
const auto beamline = loadBeamline(filename);
80-
const auto rays = tracer->trace(beamline, Sequential::No, DEFAULT_BATCH_SIZE, beamline.numElements() + 2, -1);
80+
const auto rays = tracer->trace(beamline, Sequential::No, DEFAULT_BATCH_SIZE, beamline.numElements() + 2);
8181
return raySoAToBundleHistory(rays);
8282
}
8383

@@ -177,7 +177,7 @@ std::optional<RAYX::Ray> lastSequentialHit(RayHistory ray_hist, uint32_t beamlin
177177
// returns the rayx rays converted to be ray-UI compatible.
178178
std::vector<RAYX::Ray> rayUiCompat(std::string filename, Sequential seq) {
179179
const auto beamline = loadBeamline(filename);
180-
const auto rays = tracer->trace(beamline, seq, DEFAULT_BATCH_SIZE, beamline.numElements() + 2, -1);
180+
const auto rays = tracer->trace(beamline, seq, DEFAULT_BATCH_SIZE, beamline.numElements() + 2);
181181
const auto hist = raySoAToBundleHistory(rays);
182182

183183
std::vector<RAYX::Ray> out;

Intern/rayx-ui/src/Simulator.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@ void Simulator::runSimulation() {
1717
m_maxEvents = RAYX::Tracer::defaultMaxEvents(&m_Beamline);
1818
}
1919

20-
constexpr int RECORD_ALL_ELEMENTS = -1;
21-
const auto rays =
22-
m_Tracer->trace(m_Beamline, m_seq, m_max_batch_size, m_maxEvents, RECORD_ALL_ELEMENTS); // TODO: implement recordElementIndex for GUI?
20+
const auto rays = m_Tracer->trace(m_Beamline, m_seq, m_max_batch_size, m_maxEvents);
2321
const auto bundleHist = RAYX::raySoAToBundleHistory(rays);
2422

2523
bool notEnoughEvents = false;

Intern/rayx/src/CommandParser.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@ CommandParser::CommandParser(int _argc, char* const* _argv) : m_cli11{std::make_
2121
m_cli11->add_flag(_name, *static_cast<bool*>(option.second.option_flag), _description);
2222
} else if (_type == OptionType::STRING) {
2323
m_cli11->add_option(_name, *static_cast<std::string*>(option.second.option_flag), _description);
24-
} else if (_type == OptionType::BOOL_STRING) { // Discarded
25-
m_cli11->add_flag(_name, *static_cast<bool*>(option.second.option_flag), _description);
2624
} else if (_type == OptionType::INT) {
2725
m_cli11->add_option(_name, *static_cast<int*>(option.second.option_flag), _description);
26+
} else if (_type == OptionType::INT_VEC) {
27+
m_cli11->add_option(_name, *static_cast<std::vector<int>*>(option.second.option_flag), _description)
28+
->expected(-1); // allow any number of ints
2829
}
2930
}
3031

Intern/rayx/src/CommandParser.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ class CommandParser {
4141
bool m_verbose = false; // --verbose (Verbose)
4242
std::string m_format; // --format
4343
int m_maxEvents = -1; // -m (max events)
44-
int m_recordElementIndex = -1; // -R --record-element (element index)
45-
std::string m_dump = ""; // -D (dump)
44+
std::vector<int> m_recordIndices = {};
45+
std::string m_dump = ""; // -D (dump)
4646
} m_args;
4747

4848
static inline void getVersion() {
@@ -61,7 +61,7 @@ class CommandParser {
6161

6262
private:
6363
int m_cli11_return;
64-
enum OptionType { BOOL, INT, STRING, BOOL_STRING };
64+
enum OptionType { BOOL, INT, STRING, INT_VEC };
6565
struct Options {
6666
// CLI::Option cli11_option;
6767
const OptionType type;
@@ -97,7 +97,7 @@ class CommandParser {
9797
"\"",
9898
&(m_args.m_format)}},
9999
{'m', {OptionType::INT, "maxEvents", "Maximum number of events per ray", &(m_args.m_maxEvents)}},
100-
{'R', {OptionType::INT, "record-element", "Record events only for a specifc element", &(m_args.m_recordElementIndex)}},
100+
{'R', {OptionType::INT_VEC, "record-indices", "Record events only for specifc elements", &(m_args.m_recordIndices)}},
101101
{'D', {OptionType::STRING, "dump", "Dump the meta data of a file (h5 or rml)", &(m_args.m_dump)}},
102102
};
103103
};

0 commit comments

Comments
 (0)