Skip to content

Commit 1edea06

Browse files
authored
refactor(sysrap): make shared RNG use explicit via traits (#300)
This removes fragile include-order coupling around the global `RNG` alias. Shared headers like `sysrap/storch.h` were implicitly depending on callers to define `using RNG = ...` before inclusion, which made CPU mock paths and CUDA paths easy to break when a header pulled in `sysrap/srng.h` directly. The change makes `RNG` selection explicit at call sites by templating shared generation functions over the `RNG` type and routing random access through `srng_traits`. CPU code can now pass `srngcpu` directly, while CUDA code continues to use the selected device `RNG`. This keeps CPU mock generation and GPU generation sharing the same implementation without relying on hidden global alias state.
1 parent 1760210 commit 1edea06

15 files changed

Lines changed: 149 additions & 87 deletions

sysrap/SEvent.cc

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77
#include "stran.h"
88
#include "sframe.h"
99

10-
#include "srngcpu.h"
11-
using RNG = srngcpu ;
12-
1310
#include "storch.h"
1411
#include "scerenkov.h"
1512
#include "sscint.h"
@@ -414,7 +411,3 @@ std::string SEvent::DescSeed( const int* seed, int num_seed, int edgeitems ) //
414411
std::string s = ss.str();
415412
return s ;
416413
}
417-
418-
419-
420-

sysrap/SGenerate.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ struct SGenerate
3131
#include "sphoton.h"
3232

3333
#include "srngcpu.h"
34-
using RNG = srngcpu ;
3534

3635
#include "storch.h"
3736
#include "scarrier.h"
@@ -79,7 +78,7 @@ inline NP* SGenerate::GeneratePhotons(const NP* gs_ )
7978
sphoton* pp = (sphoton*)ph->bytes() ;
8079

8180
unsigned rng_seed = 1u ;
82-
RNG rng ;
81+
srngcpu rng;
8382
rng.seed = rng_seed ;
8483

8584
for(int i=0 ; i < tot_photon ; i++ )
@@ -103,4 +102,3 @@ inline NP* SGenerate::GeneratePhotons(const NP* gs_ )
103102
return ph ;
104103
}
105104

106-

sysrap/sboundary.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ components, then combine the Fresnel transmission/reflection coefficients.
2121
2222
**/
2323

24+
#include "srng_traits.h"
2425

2526
struct sboundary
2627
{
@@ -67,11 +68,11 @@ struct sboundary
6768
float3 A_parallel ;
6869
float3 alt_pol ; // check an alternative polarization expression
6970

70-
sboundary(RNG& rng, sctx& ctx );
71+
template <typename Rng> sboundary(Rng &rng, sctx &ctx);
7172
};
7273

73-
inline sboundary::sboundary( RNG& rng, sctx& ctx )
74-
:
74+
template <typename Rng>
75+
inline sboundary::sboundary(Rng &rng, sctx &ctx) :
7576
p(ctx.p),
7677
s(ctx.s),
7778
n1(s.material1.x),
@@ -103,7 +104,7 @@ inline sboundary::sboundary( RNG& rng, sctx& ctx )
103104
TT(normalize(E2_t)),
104105
TransCoeff(tir || n1c1 == 0.f ? 0.f : n2c2*dot(E2_t,E2_t)/n1c1),
105106
ReflectCoeff(1.f - TransCoeff),
106-
u_reflect(curand_uniform(&rng)),
107+
u_reflect(srng_uniform(rng)),
107108
reflect(u_reflect > TransCoeff),
108109
flag(reflect ? BOUNDARY_REFLECT : BOUNDARY_TRANSMIT),
109110
Coeff(reflect ? ReflectCoeff : TransCoeff),

sysrap/scarrier.h

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,6 @@
1414
#include "scuda.h"
1515
#include "squad.h"
1616

17-
#if defined(__CUDACC__) || defined(__CUDABE__)
18-
#else
19-
#include "srngcpu.h"
20-
#endif
21-
22-
23-
2417
struct scarrier
2518
{
2619
quad q0 ;
@@ -30,10 +23,9 @@ struct scarrier
3023
quad q4 ;
3124
quad q5 ;
3225

33-
34-
SCARRIER_METHOD static void generate( sphoton& p, RNG& rng, const quad6& gs, unsigned long long photon_id, unsigned genstep_id );
35-
36-
26+
template <typename Rng>
27+
SCARRIER_METHOD static void generate(sphoton &p, Rng &rng, const quad6 &gs, unsigned long long photon_id,
28+
unsigned genstep_id);
3729

3830
#if defined(__CUDACC__) || defined(__CUDABE__) || defined(MOCK_CURAND) || defined(MOCK_CUDA)
3931
#else
@@ -42,10 +34,13 @@ struct scarrier
4234

4335
};
4436

45-
46-
47-
inline SCARRIER_METHOD void scarrier::generate( sphoton& p_, RNG& rng, const quad6& gs, unsigned long long photon_id, unsigned genstep_id ) // static
37+
template <typename Rng>
38+
inline SCARRIER_METHOD void scarrier::generate(sphoton &p_, Rng &rng, const quad6 &gs, unsigned long long photon_id,
39+
unsigned genstep_id)
4840
{
41+
(void)rng;
42+
(void)genstep_id;
43+
4944
quad4& p = (quad4&)p_ ;
5045

5146
p.q0.f = gs.q2.f ;
@@ -72,6 +67,3 @@ inline void scarrier::FillGenstep( scarrier& gs, int matline, int numphoton_per_
7267

7368
}
7469
#endif
75-
76-
77-

sysrap/scurand.h

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,31 @@
44
#define SCURAND_METHOD __device__
55
#include "curand_kernel.h"
66
#else
7-
#define SCURAND_METHOD
8-
#include "srngcpu.h"
9-
#endif
7+
#define SCURAND_METHOD
8+
#endif
9+
10+
#include "srng_traits.h"
1011

1112
template <typename T>
1213
struct scurand
1314
{
14-
static SCURAND_METHOD T uniform( RNG* rng );
15+
template <typename Rng> static SCURAND_METHOD T uniform(Rng *rng);
1516
};
1617

17-
18-
19-
template<> inline float scurand<float>::uniform( RNG* rng )
18+
template <> template <typename Rng> inline SCURAND_METHOD float scurand<float>::uniform(Rng *rng)
2019
{
2120
#ifdef FLIP_RANDOM
22-
return 1.f - curand_uniform(rng) ;
21+
return 1.f - srng_uniform(*rng);
2322
#else
24-
return curand_uniform(rng) ;
23+
return srng_uniform(*rng);
2524
#endif
2625
}
2726

28-
template<> inline double scurand<double>::uniform( RNG* rng )
27+
template <> template <typename Rng> inline SCURAND_METHOD double scurand<double>::uniform(Rng *rng)
2928
{
3029
#ifdef FLIP_RANDOM
31-
return 1. - curand_uniform_double(rng) ;
30+
return 1. - srng_uniform_double(*rng);
3231
#else
33-
return curand_uniform_double(rng) ;
32+
return srng_uniform_double(*rng);
3433
#endif
3534
}
36-
37-
38-

sysrap/srng.h

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,44 +17,94 @@ So have to implement all methods in each specialization, or use a separate helpe
1717
1818
**/
1919

20+
#include "srng_traits.h"
2021
#include <curand_kernel.h>
2122

2223
using XORWOW = curandStateXORWOW ;
2324
using Philox = curandStatePhilox4_32_10 ;
2425

2526
#if defined(RNG_XORWOW)
26-
using RNG = XORWOW ;
27+
using DefaultDeviceRNG = XORWOW;
2728
#elif defined(RNG_PHILOX)
28-
using RNG = Philox ;
29+
using DefaultDeviceRNG = Philox;
2930
#endif
3031

32+
#if defined(RNG_XORWOW) || defined(RNG_PHILOX)
33+
#if !defined(MOCK_CURAND) && !defined(MOCK_CUDA)
34+
using RNG = DefaultDeviceRNG;
35+
#endif
36+
#endif
3137

3238
#if defined(__CUDACC__) || defined(__CUDABE__)
39+
40+
template <> struct srng<XORWOW>
41+
{
42+
static SRNG_METHOD float uniform(XORWOW &state)
43+
{
44+
return curand_uniform(&state);
45+
}
46+
static SRNG_METHOD double uniform_double(XORWOW &state)
47+
{
48+
return curand_uniform_double(&state);
49+
}
50+
};
51+
52+
template <> struct srng<Philox>
53+
{
54+
static SRNG_METHOD float uniform(Philox &state)
55+
{
56+
return curand_uniform(&state);
57+
}
58+
59+
static SRNG_METHOD double uniform_double(Philox &state)
60+
{
61+
return curand_uniform_double(&state);
62+
}
63+
};
64+
3365
#else
3466

3567
#include <cstring>
3668
#include <sstream>
3769
#include <string>
3870

39-
template<typename T> struct srng {};
40-
4171
// template specializations for the different states
4272
template<>
4373
struct srng<XORWOW>
4474
{
4575
static constexpr char CODE = 'X' ;
46-
static constexpr const char* NAME = "XORWOW" ;
47-
static constexpr unsigned SIZE = sizeof(XORWOW) ;
48-
static constexpr bool UPLOAD_RNG_STATES = true ;
76+
static constexpr const char *NAME = "XORWOW";
77+
static constexpr unsigned SIZE = sizeof(XORWOW);
78+
static constexpr bool UPLOAD_RNG_STATES = true;
79+
80+
static inline float uniform(XORWOW &state)
81+
{
82+
return curand_uniform(&state);
83+
}
84+
85+
static inline double uniform_double(XORWOW &state)
86+
{
87+
return curand_uniform_double(&state);
88+
}
4989
};
5090

5191
template<>
5292
struct srng<Philox>
5393
{
5494
static constexpr char CODE = 'P' ;
55-
static constexpr const char* NAME = "Philox" ;
56-
static constexpr unsigned SIZE = sizeof(Philox) ;
57-
static constexpr bool UPLOAD_RNG_STATES = false ;
95+
static constexpr const char *NAME = "Philox";
96+
static constexpr unsigned SIZE = sizeof(Philox);
97+
static constexpr bool UPLOAD_RNG_STATES = false;
98+
99+
static inline float uniform(Philox &state)
100+
{
101+
return curand_uniform(&state);
102+
}
103+
104+
static inline double uniform_double(Philox &state)
105+
{
106+
return curand_uniform_double(&state);
107+
}
58108
};
59109

60110
// helper function

sysrap/srng_traits.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#pragma once
2+
/**
3+
srng_traits.h : common RNG traits interface
4+
===========================================
5+
6+
Shared CPU/GPU generation headers should accept an RNG object explicitly
7+
instead of depending on an include-order-selected global RNG alias.
8+
9+
Concrete RNG headers specialize srng<T> and provide uniform accessors.
10+
**/
11+
12+
#if defined(__CUDACC__) || defined(__CUDABE__)
13+
#define SRNG_METHOD __device__
14+
#else
15+
#define SRNG_METHOD inline
16+
#endif
17+
18+
template <typename T> struct srng;
19+
20+
template <typename Rng> SRNG_METHOD float srng_uniform(Rng &rng)
21+
{
22+
return srng<Rng>::uniform(rng);
23+
}
24+
25+
template <typename Rng> SRNG_METHOD double srng_uniform_double(Rng &rng)
26+
{
27+
return srng<Rng>::uniform_double(rng);
28+
}

sysrap/srngcpu.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ To extend that see::
2424

2525
#include <random>
2626
#include "s_seq.h"
27+
#include "srng_traits.h"
2728

2829
struct srngcpu
2930
{
@@ -118,7 +119,19 @@ inline std::string srngcpu::demo(int n)
118119
inline float curand_uniform(srngcpu* state ){ return state->generate_float() ; }
119120
inline double curand_uniform_double(srngcpu* state ){ return state->generate_double() ; }
120121

121-
122-
123-
124-
122+
template <> struct srng<srngcpu>
123+
{
124+
static constexpr char CODE = 'C';
125+
static constexpr const char *NAME = "srngcpu";
126+
static constexpr unsigned SIZE = sizeof(srngcpu);
127+
static constexpr bool UPLOAD_RNG_STATES = false;
128+
129+
static inline float uniform(srngcpu &state)
130+
{
131+
return state.generate_float();
132+
}
133+
static inline double uniform_double(srngcpu &state)
134+
{
135+
return state.generate_double();
136+
}
137+
};

0 commit comments

Comments
 (0)