Skip to content

Commit 66a2fcd

Browse files
committed
update
1 parent 31f8372 commit 66a2fcd

File tree

2 files changed

+147
-41
lines changed

2 files changed

+147
-41
lines changed

genesis/engine/solvers/rigid/constraint/solver.py

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
USE_LS_OPT = os.environ.get("GS_SOLVER_LS_OPT", "0") == "1"
2424

25-
# TODO: set always true for CI benchmark use
25+
# TODO: USE_LS_OPT is forced on below for CI benchmark use, overriding the GS_SOLVER_LS_OPT environment variable; revisit this override.
2626
USE_LS_OPT = 1
2727

2828
IS_OLD_TORCH = tuple(map(int, torch.__version__.split(".")[:2])) < (2, 8)
@@ -2184,7 +2184,9 @@ def func_ls_init_and_eval_p0_opt(
21842184

21852185
# Compute quad for all constraints and write to global
21862186
for i_c in range(n_con):
2187-
qf_0 = constraint_state.efc_D[i_c, i_b] * (0.5 * constraint_state.Jaref[i_c, i_b] * constraint_state.Jaref[i_c, i_b])
2187+
qf_0 = constraint_state.efc_D[i_c, i_b] * (
2188+
0.5 * constraint_state.Jaref[i_c, i_b] * constraint_state.Jaref[i_c, i_b]
2189+
)
21882190
qf_1 = constraint_state.efc_D[i_c, i_b] * (constraint_state.jv[i_c, i_b] * constraint_state.Jaref[i_c, i_b])
21892191
qf_2 = constraint_state.efc_D[i_c, i_b] * (0.5 * constraint_state.jv[i_c, i_b] * constraint_state.jv[i_c, i_b])
21902192
constraint_state.quad[i_c, 0, i_b] = qf_0
@@ -2223,9 +2225,7 @@ def func_ls_init_and_eval_p0_opt(
22232225
qf_0 = linear_neg * f * (-0.5 * rf - constraint_state.Jaref[i_c, i_b]) + linear_pos * f * (
22242226
-0.5 * rf + constraint_state.Jaref[i_c, i_b]
22252227
)
2226-
qf_1 = linear_neg * (-f * constraint_state.jv[i_c, i_b]) + linear_pos * (
2227-
f * constraint_state.jv[i_c, i_b]
2228-
)
2228+
qf_1 = linear_neg * (-f * constraint_state.jv[i_c, i_b]) + linear_pos * (f * constraint_state.jv[i_c, i_b])
22292229
qf_2 = 0.0
22302230
tmp_0 = tmp_0 + qf_0
22312231
tmp_1 = tmp_1 + qf_1
@@ -2285,9 +2285,7 @@ def func_ls_point_fn_opt(
22852285
qf_0 = linear_neg * f * (-0.5 * rf - constraint_state.Jaref[i_c, i_b]) + linear_pos * f * (
22862286
-0.5 * rf + constraint_state.Jaref[i_c, i_b]
22872287
)
2288-
qf_1 = linear_neg * (-f * constraint_state.jv[i_c, i_b]) + linear_pos * (
2289-
f * constraint_state.jv[i_c, i_b]
2290-
)
2288+
qf_1 = linear_neg * (-f * constraint_state.jv[i_c, i_b]) + linear_pos * (f * constraint_state.jv[i_c, i_b])
22912289
qf_2 = 0.0
22922290
tmp_0 = tmp_0 + qf_0
22932291
tmp_1 = tmp_1 + qf_1
@@ -2669,17 +2667,35 @@ def func_linesearch_batch(
26692667
# Batch evaluate all 3 in one constraint loop
26702668
if ti.static(USE_LS_OPT):
26712669
(
2672-
_a0, c0, c0_d0, c0_d1,
2673-
_a1, c1, c1_d0, c1_d1,
2674-
_a2, c2, c2_d0, c2_d1,
2670+
_a0,
2671+
c0,
2672+
c0_d0,
2673+
c0_d1,
2674+
_a1,
2675+
c1,
2676+
c1_d0,
2677+
c1_d1,
2678+
_a2,
2679+
c2,
2680+
c2_d0,
2681+
c2_d1,
26752682
) = func_ls_point_fn_3alphas_opt(
26762683
i_b, alpha_0, alpha_1, alpha_2, constraint_state, rigid_global_info
26772684
)
26782685
else:
26792686
(
2680-
_a0, c0, c0_d0, c0_d1,
2681-
_a1, c1, c1_d0, c1_d1,
2682-
_a2, c2, c2_d0, c2_d1,
2687+
_a0,
2688+
c0,
2689+
c0_d0,
2690+
c0_d1,
2691+
_a1,
2692+
c1,
2693+
c1_d0,
2694+
c1_d1,
2695+
_a2,
2696+
c2,
2697+
c2_d0,
2698+
c2_d1,
26832699
) = func_ls_point_fn_3alphas(
26842700
i_b, alpha_0, alpha_1, alpha_2, constraint_state, rigid_global_info
26852701
)
@@ -2709,20 +2725,54 @@ def func_linesearch_batch(
27092725
done = True
27102726
else:
27112727
(
2712-
b1, p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1, p1_next_alpha,
2728+
b1,
2729+
p1_alpha,
2730+
p1_cost,
2731+
p1_deriv_0,
2732+
p1_deriv_1,
2733+
p1_next_alpha,
27132734
) = update_bracket_no_eval_local(
2714-
p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1,
2715-
alpha_0, c0, c0_d0, c0_d1,
2716-
alpha_1, c1, c1_d0, c1_d1,
2717-
alpha_2, c2, c2_d0, c2_d1,
2735+
p1_alpha,
2736+
p1_cost,
2737+
p1_deriv_0,
2738+
p1_deriv_1,
2739+
alpha_0,
2740+
c0,
2741+
c0_d0,
2742+
c0_d1,
2743+
alpha_1,
2744+
c1,
2745+
c1_d0,
2746+
c1_d1,
2747+
alpha_2,
2748+
c2,
2749+
c2_d0,
2750+
c2_d1,
27182751
)
27192752
(
2720-
b2, p2_alpha, p2_cost, p2_deriv_0, p2_deriv_1, p2_next_alpha,
2753+
b2,
2754+
p2_alpha,
2755+
p2_cost,
2756+
p2_deriv_0,
2757+
p2_deriv_1,
2758+
p2_next_alpha,
27212759
) = update_bracket_no_eval_local(
2722-
p2_alpha, p2_cost, p2_deriv_0, p2_deriv_1,
2723-
alpha_0, c0, c0_d0, c0_d1,
2724-
alpha_1, c1, c1_d0, c1_d1,
2725-
alpha_2, c2, c2_d0, c2_d1,
2760+
p2_alpha,
2761+
p2_cost,
2762+
p2_deriv_0,
2763+
p2_deriv_1,
2764+
alpha_0,
2765+
c0,
2766+
c0_d0,
2767+
c0_d1,
2768+
alpha_1,
2769+
c1,
2770+
c1_d0,
2771+
c1_d1,
2772+
alpha_2,
2773+
c2,
2774+
c2_d0,
2775+
c2_d1,
27262776
)
27272777

27282778
if b1 == 0 and b2 == 0:

genesis/engine/solvers/rigid/constraint/solver_island.py

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -702,10 +702,23 @@ def _func_ls_point_fn_3alphas(self, i_b, alpha_0, alpha_1, alpha_2):
702702

703703
@ti.func
704704
def update_bracket_no_eval(
705-
self, p_alpha, p_cost, p_deriv_0, p_deriv_1,
706-
c0_alpha, c0_cost, c0_d0, c0_d1,
707-
c1_alpha, c1_cost, c1_d0, c1_d1,
708-
c2_alpha, c2_cost, c2_d0, c2_d1,
705+
self,
706+
p_alpha,
707+
p_cost,
708+
p_deriv_0,
709+
p_deriv_1,
710+
c0_alpha,
711+
c0_cost,
712+
c0_d0,
713+
c0_d1,
714+
c1_alpha,
715+
c1_cost,
716+
c1_d0,
717+
c1_d1,
718+
c2_alpha,
719+
c2_cost,
720+
c2_d0,
721+
c2_d1,
709722
):
710723
"""Bracket update using local candidate values. No global memory access or _func_ls_point_fn call."""
711724
flag = 0
@@ -813,9 +826,18 @@ def _func_linesearch(self, island, i_b):
813826

814827
while self.ls_it[i_b] < self.ls_iterations:
815828
(
816-
_a0, c0, c0_d0, c0_d1,
817-
_a1, c1, c1_d0, c1_d1,
818-
_a2, c2, c2_d0, c2_d1,
829+
_a0,
830+
c0,
831+
c0_d0,
832+
c0_d1,
833+
_a1,
834+
c1,
835+
c1_d0,
836+
c1_d1,
837+
_a2,
838+
c2,
839+
c2_d0,
840+
c2_d1,
819841
) = self._func_ls_point_fn_3alphas(i_b, alpha_0, alpha_1, alpha_2)
820842

821843
p1_next_alpha = alpha_0
@@ -842,20 +864,54 @@ def _func_linesearch(self, island, i_b):
842864
done = True
843865
else:
844866
(
845-
b1, p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1, p1_next_alpha,
867+
b1,
868+
p1_alpha,
869+
p1_cost,
870+
p1_deriv_0,
871+
p1_deriv_1,
872+
p1_next_alpha,
846873
) = self.update_bracket_no_eval(
847-
p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1,
848-
alpha_0, c0, c0_d0, c0_d1,
849-
alpha_1, c1, c1_d0, c1_d1,
850-
alpha_2, c2, c2_d0, c2_d1,
874+
p1_alpha,
875+
p1_cost,
876+
p1_deriv_0,
877+
p1_deriv_1,
878+
alpha_0,
879+
c0,
880+
c0_d0,
881+
c0_d1,
882+
alpha_1,
883+
c1,
884+
c1_d0,
885+
c1_d1,
886+
alpha_2,
887+
c2,
888+
c2_d0,
889+
c2_d1,
851890
)
852891
(
853-
b2, p2_alpha, p2_cost, p2_deriv_0, p2_deriv_1, p2_next_alpha,
892+
b2,
893+
p2_alpha,
894+
p2_cost,
895+
p2_deriv_0,
896+
p2_deriv_1,
897+
p2_next_alpha,
854898
) = self.update_bracket_no_eval(
855-
p2_alpha, p2_cost, p2_deriv_0, p2_deriv_1,
856-
alpha_0, c0, c0_d0, c0_d1,
857-
alpha_1, c1, c1_d0, c1_d1,
858-
alpha_2, c2, c2_d0, c2_d1,
899+
p2_alpha,
900+
p2_cost,
901+
p2_deriv_0,
902+
p2_deriv_1,
903+
alpha_0,
904+
c0,
905+
c0_d0,
906+
c0_d1,
907+
alpha_1,
908+
c1,
909+
c1_d0,
910+
c1_d1,
911+
alpha_2,
912+
c2,
913+
c2_d0,
914+
c2_d1,
859915
)
860916

861917
if b1 == 0 and b2 == 0:

0 commit comments

Comments
 (0)