77#include " imgui.h"
88#include " polyscope/polyscope.h"
99
10+ #include " rxmesh/geometry_factory.h"
1011
1112using namespace rxmesh ;
1213
@@ -46,16 +47,16 @@ void __global__ solve_stretch(const Context context,
4647 const float dt2)
4748{
4849 auto solve = [&](const EdgeHandle& eh, const VertexIterator& iter) {
49- auto v0 = iter[0 ];
50- auto v1 = iter[1 ];
50+ auto v1 = iter[0 ];
51+ auto v2 = iter[1 ];
5152
52- const glm::fvec3 x0 = new_x.to_glm <3 >(v0);
5353 const glm::fvec3 x1 = new_x.to_glm <3 >(v1);
54+ const glm::fvec3 x2 = new_x.to_glm <3 >(v2);
5455
55- const float w1 (invM (v0 , 0 )), w2 (invM (v1 , 0 ));
56+ const float w1 (invM (v1 , 0 )), w2 (invM (v2 , 0 ));
5657
5758 if (w1 + w2 > 0 .f ) {
58- glm::fvec3 n = x0 - x1 ;
59+ glm::fvec3 n = x1 - x2 ;
5960 const float d = glm::length (n);
6061 glm::fvec3 dpp (0 .f , 0 .f , 0 .f );
6162 const float constraint = (d - rest_len (eh, 0 ));
@@ -81,8 +82,8 @@ void __global__ solve_stretch(const Context context,
8182 }
8283
8384 for (int i = 0 ; i < 3 ; ++i) {
84- ::atomicAdd (&dp (v0 , i), dpp[i] * w1);
85- ::atomicAdd (&dp (v1 , i), -(dpp[i] * w2));
85+ ::atomicAdd (&dp (v1 , i), dpp[i] * w1);
86+ ::atomicAdd (&dp (v2 , i), -(dpp[i] * w2));
8687 }
8788 }
8889 };
@@ -182,6 +183,145 @@ void __global__ solve_bending(const Context context,
182183 ShmemAllocator shrd_alloc;
183184 query.dispatch<Op::EVDiamond>(block, shrd_alloc, solve);
184185}
186+
187+ template <uint32_t blockThreads, bool XPBD>
188+ void __global__ solve_stretch_and_bending (const Context context,
189+ VertexAttribute<float > dp,
190+ EdgeAttribute<float > la_s,
191+ EdgeAttribute<float > la_b,
192+ const VertexAttribute<float > invM,
193+ const VertexAttribute<float > new_x,
194+ const EdgeAttribute<float > rest_len,
195+ const float stretch_compliance,
196+ const float stretch_relaxation,
197+ const float bending_compliance,
198+ const float bending_relaxation,
199+ const float dt2_inv)
200+ {
201+ auto solve = [&](const EdgeHandle& eh, const VertexIterator& iter) {
202+ // iter[0] and iter[2] are the edge two vertices
203+ // iter[1] and iter[3] are the two opposite vertices
204+
205+ auto v1 = iter[0 ];
206+ auto v2 = iter[2 ];
207+
208+ auto v3 = iter[1 ];
209+ auto v4 = iter[3 ];
210+
211+ float v1_st[3 ] = {0 , 0 , 0 };
212+ float v2_st[3 ] = {0 , 0 , 0 };
213+
214+ const float w1 (invM (v1, 0 )), w2 (invM (v2, 0 ));
215+
216+ const glm::fvec3 x1 = new_x.to_glm <3 >(v1);
217+ const glm::fvec3 x2 = new_x.to_glm <3 >(v2);
218+
219+ // stretch term (for v1 and v2)
220+ if (w1 + w2 > 0 .f ) {
221+ glm::fvec3 n = x1 - x2;
222+ const float d = glm::length (n);
223+ glm::fvec3 dpp (0 .f , 0 .f , 0 .f );
224+ const float constraint = (d - rest_len (eh, 0 ));
225+
226+ n = glm::normalize (n);
227+ if constexpr (XPBD) {
228+ const float compliance = stretch_compliance * dt2_inv;
229+
230+ const float d_lambda =
231+ -(constraint + compliance * la_s (eh, 0 )) /
232+ (w1 + w2 + compliance) * stretch_relaxation;
233+
234+ for (int i = 0 ; i < 3 ; ++i) {
235+ dpp[i] = d_lambda * n[i];
236+ }
237+ la_s (eh, 0 ) += d_lambda;
238+
239+ } else {
240+ for (int i = 0 ; i < 3 ; ++i) {
241+ dpp[i] =
242+ -constraint / (w1 + w2) * n[i] * stretch_relaxation;
243+ }
244+ }
245+
246+ for (int i = 0 ; i < 3 ; ++i) {
247+ v1_st[i] = dpp[i] * w1;
248+ v2_st[i] = -(dpp[i] * w2);
249+ }
250+ }
251+
252+ // bending term (for v1, v2, v3, and v4)
253+ if (v3.is_valid () && v4.is_valid ()) {
254+ const float w3 (invM (v3, 0 )), w4 (invM (v4, 0 ));
255+
256+ const glm::fvec3 x3 = new_x.to_glm <3 >(v3);
257+ const glm::fvec3 x4 = new_x.to_glm <3 >(v4);
258+
259+
260+ if (w1 + w2 + w3 + w4 > 0 .f ) {
261+
262+ glm::fvec3 p2 = x2 - x1;
263+ glm::fvec3 p3 = x3 - x1;
264+ glm::fvec3 p4 = x4 - x1;
265+
266+ float l23 = glm::length (glm::cross (p2, p3));
267+ float l24 = glm::length (glm::cross (p2, p4));
268+ if (l23 < 1e-8 ) {
269+ l23 = 1 .f ;
270+ }
271+ if (l24 < 1e-8 ) {
272+ l24 = 1 .f ;
273+ }
274+ glm::fvec3 n1 = glm::cross (p2, p3);
275+ n1 /= l23;
276+ glm::fvec3 n2 = glm::cross (p2, p4);
277+ n2 /= l24;
278+
279+ // clamp(dot(n1, n2), -1., 1.)
280+ float d = std::max (1 .f , std::min (dot (n1, n2), -1 .f ));
281+
282+ glm::fvec3 q3 = (cross (p2, n2) + cross (n1, p2) * d) / l23;
283+ glm::fvec3 q4 = (cross (p2, n1) + cross (n2, p2) * d) / l24;
284+ glm::fvec3 q2 = -(cross (p3, n2) + cross (n1, p3) * d) / l23 -
285+ (cross (p4, n1) + cross (n2, p4) * d) / l24;
286+ glm::fvec3 q1 = -q2 - q3 - q4;
287+
288+ float sum_wq = w1 * glm::length2 (q1) + w2 * glm::length2 (q2) +
289+ w3 * glm::length2 (q3) + w4 * glm::length2 (q4);
290+ float constraint = acos (d) - acos (-1 .);
291+
292+ if constexpr (XPBD) {
293+ float compliance = bending_compliance * dt2_inv;
294+ float d_lambda = -(constraint + compliance * la_b (eh, 0 )) /
295+ (sum_wq + compliance) * bending_relaxation;
296+
297+ constraint = sqrt (1 - d * d) * d_lambda;
298+ la_b (eh, 0 ) += d_lambda;
299+ } else {
300+ constraint = -sqrt (1 - d * d) * constraint /
301+ (sum_wq + 1e-7 ) * bending_relaxation;
302+ }
303+ for (int i = 0 ; i < 3 ; ++i) {
304+ ::atomicAdd (&dp (v1, i), w1 * constraint * q1[i] + v1_st[i]);
305+ ::atomicAdd (&dp (v2, i), w2 * constraint * q2[i] + v2_st[i]);
306+ ::atomicAdd (&dp (v3, i), w3 * constraint * q3[i]);
307+ ::atomicAdd (&dp (v4, i), w4 * constraint * q4[i]);
308+ }
309+ }
310+ } else {
311+ for (int i = 0 ; i < 3 ; ++i) {
312+ ::atomicAdd (&dp (v1, i), v1_st[i]);
313+ ::atomicAdd (&dp (v2, i), v2_st[i]);
314+ }
315+ }
316+ };
317+
318+ auto block = cooperative_groups::this_thread_block();
319+
320+ Query<blockThreads> query (context);
321+ ShmemAllocator shrd_alloc;
322+ query.dispatch<Op::EVDiamond>(block, shrd_alloc, solve);
323+ }
324+
185325int main (int argc, char ** argv)
186326{
187327 Log::init ();
@@ -197,11 +337,19 @@ int main(int argc, char** argv)
197337
198338 RXMeshStatic rx (STRINGIFY (INPUT_DIR) " cloth.obj" );
199339
340+ // std::vector<std::vector<float>> verts;
341+ // std::vector<std::vector<uint32_t>> fv;
342+ // const int nnn = 540;
343+ // const float dxx = 1.0f / float(nnn);
344+ // rxmesh::create_plane(verts, fv, nnn, nnn, 2, dxx);
345+ // RXMeshStatic rx(fv);
346+ // rx.add_vertex_coordinates(verts, "Coords");
347+
200348 // scale mesh info unit bounding box
201349 rx.scale ({0 .f , 0 .f , 0 .f }, {1 .f , 1 .f , 1 .f });
202350
203351
204- constexpr uint32_t blockThreads = 256 ;
352+ constexpr uint32_t blockThreads = 320 ;
205353
206354 // XPBD paramters
207355 const float frame_dt = 1e-2 ;
@@ -213,7 +361,7 @@ int main(int argc, char** argv)
213361 const float stretch_compliance = 1e-7 ;
214362 const float bending_compliance = 1e-6 ;
215363 const float mass = 1.0 ;
216- const bool XPBD = false ;
364+ constexpr bool XPBD = true ;
217365
218366 // fixtures paramters
219367 const glm::fvec4 fixure_spheres[4 ] = {{0 .f , 1 .f , 0 .f , 0.004 },
@@ -252,34 +400,39 @@ int main(int argc, char** argv)
252400 }
253401 });
254402
255- LaunchBox<blockThreads> init_edges_lb;
403+
404+ LaunchBox<blockThreads> solve_lb;
405+
256406 LaunchBox<blockThreads> solve_stretch_lb;
257407 LaunchBox<blockThreads> solve_bending_lb;
258408
259- rx.prepare_launch_box (
260- {Op::EV}, init_edges_lb, (void *)init_edges<blockThreads>);
261-
262409 rx.prepare_launch_box (
263410 {Op::EV}, solve_stretch_lb, (void *)solve_stretch<blockThreads>);
264411
265412 rx.prepare_launch_box (
266413 {Op::EVDiamond}, solve_bending_lb, (void *)solve_bending<blockThreads>);
267414
268- init_edges<blockThreads>
269- <<<init_edges_lb.blocks,
270- init_edges_lb.num_threads,
271- init_edges_lb.smem_bytes_dyn>>> (rx.get_context (), *x, *rest_len);
415+ rx.prepare_launch_box ({Op::EVDiamond},
416+ solve_lb,
417+ (void *)solve_stretch_and_bending<blockThreads, XPBD>);
418+
419+ // init edges
420+ rx.run_kernel <blockThreads>(
421+ {Op::EV}, init_edges<blockThreads>, *x, *rest_len);
272422
273- int frame = 0 ;
423+ int frame = 0 ;
424+ int max_frames = 100 ;
274425
275- bool test = true ;
426+ bool test = false ;
276427 float mean (0 .f );
277428 float mean2 (0 .f );
278429
279430
280431 // solve
281432 bool started = false ;
282433
434+ float total_time = 0 ;
435+
283436 auto polyscope_callback = [&]() mutable {
284437 if (ImGui::Button (" Start Simulation" ) || started) {
285438 started = true ;
@@ -319,35 +472,45 @@ int main(int argc, char** argv)
319472 dp->reset (0 , DEVICE);
320473
321474 // solveStretch
322- solve_stretch<blockThreads>
323- <<<solve_stretch_lb.blocks,
324- solve_stretch_lb.num_threads,
325- solve_stretch_lb.smem_bytes_dyn>>> (
326- rx.get_context (),
327- *dp,
328- *la_s,
329- *invM,
330- *new_x,
331- *rest_len,
332- XPBD,
333- stretch_compliance,
334- stretch_relaxation,
335- dt0 * dt0);
336-
337- // solveBending
338- solve_bending<blockThreads>
339- <<<solve_bending_lb.blocks,
340- solve_bending_lb.num_threads,
341- solve_bending_lb.smem_bytes_dyn>>> (
342- rx.get_context (),
343- *dp,
344- *la_b,
345- *invM,
346- *new_x,
347- XPBD,
348- bending_compliance,
349- bending_relaxation,
350- dt0 * dt0);
475+ // rx.run_kernel(solve_stretch_lb,
476+ // solve_stretch<blockThreads>,
477+ // *dp,
478+ // *la_s,
479+ // *invM,
480+ // *new_x,
481+ // *rest_len,
482+ // XPBD,
483+ // stretch_compliance,
484+ // stretch_relaxation,
485+ // dt0 * dt0);
486+ //
487+ //
488+ // // solveBending
489+ // rx.run_kernel(solve_bending_lb,
490+ // solve_bending<blockThreads>,
491+ // *dp,
492+ // *la_b,
493+ // *invM,
494+ // *new_x,
495+ // XPBD,
496+ // bending_compliance,
497+ // bending_relaxation,
498+ // dt0 * dt0);
499+
500+ // solve Stretch and bending
501+ rx.run_kernel (solve_lb,
502+ solve_stretch_and_bending<blockThreads, XPBD>,
503+ *dp,
504+ *la_b,
505+ *la_s,
506+ *invM,
507+ *new_x,
508+ *rest_len,
509+ stretch_compliance,
510+ stretch_relaxation,
511+ bending_compliance,
512+ bending_relaxation,
513+ 1 .0f / (dt0 * dt0));
351514
352515 // postSolve
353516 rx.for_each_vertex (
@@ -386,6 +549,8 @@ int main(int argc, char** argv)
386549 timer.stop ();
387550 RXMESH_INFO (
388551 " Frame {}, time= {}(ms)" , frame, timer.elapsed_millis ());
552+ total_time += timer.elapsed_millis ();
553+
389554#if USE_POLYSCOPE
390555 x->move (DEVICE, HOST);
391556 rx.get_polyscope_mesh ()->updateVertexPositions (*x);
@@ -403,9 +568,18 @@ int main(int argc, char** argv)
403568 mean2 /= (3 .f * rx.get_num_vertices ());
404569 }
405570 }
571+ // if (frame >= max_frames) {
572+ // RXMESH_INFO("fps = {}", (frame * 100.f) / total_time);
573+ // exit(0);
574+ // }
406575 }
407576 };
408577
578+ // started = true;
579+ // while (true) {
580+ // polyscope_callback();
581+ // }
582+
409583#if USE_POLYSCOPE
410584 polyscope::state::userCallback = polyscope_callback;
411585 polyscope::show ();
0 commit comments