Skip to content

Commit 4a6dcd3

Browse files
use loadAs to get particles as user provided type with overloaded operators
1 parent 5e4673d commit 4a6dcd3

File tree

1 file changed

+65
-16
lines changed

1 file changed

+65
-16
lines changed

examples/nbody/nbody.cpp

+65-16
Original file line numberDiff line numberDiff line change
@@ -38,33 +38,82 @@ using namespace llama::literals;
3838

3939
namespace usellama
4040
{
41-
struct Particle
41+
/// Aggregate of three FP components (x, y, z) with component-wise arithmetic operators.
/// The compound operators modify the object in place and return a reference to it; the
/// binary friend operators take their operands by value and return the modified copy.
// NOTE(review): this text is a rendered diff. Lines that follow an `NN-` marker appear to be
// removed lines of the previous nested Pos/Vel layout, not members of Vec - confirm against
// the actual file before editing.
struct Vec
4242
{
43-
struct Pos
43+
FP x;
44+
FP y;
45+
FP z;
46+
47+
// Multiplies every component by the scalar s.
auto operator*=(FP s) -> Vec&
4448
{
45-
float x;
46-
float y;
47-
float z;
48-
} pos;
49-
struct Vel
49+
x *= s;
50+
y *= s;
51+
z *= s;
52+
return *this;
53+
}
54+
55+
// Multiplies each component by the matching component of v (component-wise product).
auto operator*=(Vec v) -> Vec&
5056
{
51-
float x;
52-
float y;
53-
float z;
54-
} vel;
55-
float mass;
57+
x *= v.x;
58+
y *= v.y;
59+
z *= v.z;
60+
return *this;
61+
}
62+
63+
// Adds v component-wise.
auto operator+=(Vec v) -> Vec&
64+
{
65+
x += v.x;
66+
y += v.y;
67+
z += v.z;
68+
return *this;
69+
}
70+
71+
// Subtracts v component-wise.
auto operator-=(Vec v) -> Vec&
72+
{
73+
x -= v.x;
74+
y -= v.y;
75+
z -= v.z;
76+
return *this;
77+
}
78+
79+
// Component-wise sum; implemented by applying += to the by-value copy a.
friend auto operator+(Vec a, Vec b) -> Vec
80+
{
81+
return a += b;
82+
}
83+
84+
// Component-wise difference; implemented by applying -= to the by-value copy a.
friend auto operator-(Vec a, Vec b) -> Vec
85+
{
86+
return a -= b;
87+
}
88+
89+
// Vector scaled by the scalar s; only the (Vec, FP) operand order is declared here.
friend auto operator*(Vec a, FP s) -> Vec
90+
{
91+
return a *= s;
92+
}
93+
94+
// Component-wise product of a and b (neither a dot nor a cross product).
friend auto operator*(Vec a, Vec b) -> Vec
95+
{
96+
return a *= b;
97+
}
98+
};
99+
100+
/// Plain record with two Vec members (pos, vel) and one FP member (mass).
/// pPInteraction requests this type through loadAs<Particle>() and reads .pos, .vel and .mass.
// NOTE(review): presumably pos/vel/mass are position, velocity and mass and line up with
// record coordinates 0, 1 and 2 of the LLAMA record - confirm against the record definition.
struct Particle
101+
{
102+
Vec pos;
103+
Vec vel;
104+
FP mass;
56105
};
57106

58107
/// Updates the velocity entry of particle pi from its interaction with particle pj.
/// Reads pos and vel of pi and pos and mass of pj through loadAs<Particle>(), and writes the
/// result back with pi(1_DC).store(...). Uses EPS2 and TIMESTEP, which are defined outside
/// this function.
// NOTE(review): this text is a rendered diff. Statements that follow an `NN-` marker appear
// to be the removed previous implementation; the comments below describe the `NN+` lines.
template <typename VirtualParticleI, typename VirtualParticleJ>
59108
LLAMA_FN_HOST_ACC_INLINE void pPInteraction(VirtualParticleI&& pi, VirtualParticleJ pj)
60109
{
61-
auto dist = pi(0_DC) - pj(0_DC);
110+
// Difference of the two pos members, computed with Vec's operator-.
auto dist = pi.template loadAs<Particle>().pos - pj.template loadAs<Particle>().pos;
62111
// Component-wise square: from here on dist holds the squared differences.
dist *= dist;
63-
const FP distSqr = EPS2 + dist(0_DC) + dist(1_DC) + dist(2_DC);
112+
// Sum of the squared components plus EPS2.
const FP distSqr = EPS2 + dist.x + dist.y + dist.z;
64113
const FP distSixth = distSqr * distSqr * distSqr;
65114
// 1 / sqrt(distSqr^3), i.e. distSqr raised to the power -3/2.
const FP invDistCube = 1.0f / std::sqrt(distSixth);
66-
const FP sts = pj(2_DC) * invDistCube * TIMESTEP;
67-
pi(1_DC) += dist * sts;
115+
// NOTE(review): loadAs<Particle>() is called a second time on pj here (and a second time on
// pi below); a single load per particle may suffice - confirm the cost of loadAs.
const FP sts = pj.template loadAs<Particle>().mass * invDistCube * TIMESTEP;
116+
// NOTE(review): dist still holds the squared components at this point, so the value stored
// is vel plus the squared differences scaled by sts - confirm this is intended.
pi(1_DC).store(pi.template loadAs<Particle>().vel + dist * sts);
68117
}
69118

70119
template <bool UseAccumulator, typename View>

0 commit comments

Comments
 (0)