Skip to content

Commit 640e830

Browse files
committed
adding optional access to the contact during line search
1 parent 11cb659 commit 640e830

File tree

8 files changed

+41
-63
lines changed

8 files changed

+41
-63
lines changed

apps/NeoHookean/barrier_energy.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,6 @@ void add_contact(RXMeshStatic& rx,
8181
T d = (xi - x_dbc).dot(normal);
8282

8383
if (d < dhat) {
84-
// if (vh.local_id() == 20 || vh.local_id() == 21 || vh.local_id()
85-
// == 22 || vh.local_id() == 23 || vh.local_id() == 24) {
8684
bool inserted = contact_pairs.insert(vh, dbc_vertex);
8785
assert(inserted);
8886

@@ -91,19 +89,8 @@ void add_contact(RXMeshStatic& rx,
9189

9290
inserted = contact_pairs.insert(vh, dbc_vertex2);
9391
assert(inserted);
94-
95-
printf("\n addec contact pair between %d, (%d, %d, %d)",
96-
vh.local_id(),
97-
dbc_vertex.local_id(),
98-
dbc_vertex1.local_id(),
99-
dbc_vertex2.local_id());
10092
}
10193
});
102-
103-
{
104-
CUDA_ERROR(cudaGetLastError());
105-
CUDA_ERROR(cudaDeviceSynchronize());
106-
}
10794
}
10895

10996
template <typename ProblemT,

apps/NeoHookean/neo_hookean.cu

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,6 @@ void neo_hookean(RXMeshStatic& rx, T dx)
198198
timer.add("LinearSolver");
199199
timer.add("Diff");
200200

201-
//add_contact(
202-
// rx, problem.vv_pairs, v_dbc[0], v_dbc[1], v_dbc[2], is_dbc, x, dhat);
203-
//problem.update_hessian();
204-
//printf("\n");
205-
206201
auto step_forward = [&]() {
207202
// x_tilde = x + v*h
208203
timer.start("Step");
@@ -241,10 +236,6 @@ void neo_hookean(RXMeshStatic& rx, T dx)
241236
dhat);
242237
problem.update_hessian();
243238
problem.eval_terms();
244-
{
245-
CUDA_ERROR(cudaGetLastError());
246-
CUDA_ERROR(cudaDeviceSynchronize());
247-
}
248239

249240
grad.copy_from(problem.grad, DEVICE, DEVICE);
250241
// DBC satisfied
@@ -297,8 +288,17 @@ void neo_hookean(RXMeshStatic& rx, T dx)
297288
line_search_init_step = std::min(nh_step, bar_step);
298289

299290
// TODO: line search should pass the step to the friction energy
300-
bool ls_success =
301-
newton_solver.line_search(line_search_init_step, 0.5, 64, 0.0);
291+
bool ls_success = newton_solver.line_search(
292+
line_search_init_step, 0.5, 64, 0.0, [&](auto temp_x) {
293+
add_contact(rx,
294+
problem.vv_pairs,
295+
v_dbc[0],
296+
v_dbc[1],
297+
v_dbc[2],
298+
is_dbc,
299+
temp_x,
300+
dhat);
301+
});
302302

303303
if (!ls_success) {
304304
RXMESH_WARN("Line search failed!");
@@ -315,10 +315,6 @@ void neo_hookean(RXMeshStatic& rx, T dx)
315315
dhat);
316316
problem.update_hessian();
317317
problem.eval_terms();
318-
{
319-
CUDA_ERROR(cudaGetLastError());
320-
CUDA_ERROR(cudaDeviceSynchronize());
321-
}
322318

323319
T f = problem.get_current_loss();
324320

include/rxmesh/diff/candidate_pairs.h

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -73,18 +73,12 @@ struct CandidatePairs
7373
/// if we have a space to insert new pairs
7474

7575
auto add_candidate_with_indices = [&](int id) {
76-
// add the handles
77-
m_pairs_handle(id).first = c0;
78-
m_pairs_handle(id).second = c1;
79-
8076
// add the indices
81-
id *= m_variable_dim * m_variable_dim * 2;
82-
8377
for (int i = 0; i < m_variable_dim; ++i) {
8478
for (int j = 0; j < m_variable_dim; ++j) {
8579

86-
int r_id = m_context.linear_id(c0) * m_variable_dim + i;
87-
int c_id = m_context.linear_id(c1) * m_variable_dim + j;
80+
int r_id = m_hess.get_row_id(c0) * m_variable_dim + i;
81+
int c_id = m_hess.get_row_id(c1) * m_variable_dim + j;
8882

8983
m_pairs_id(id, 0) = r_id;
9084
m_pairs_id(id, 1) = c_id;
@@ -107,13 +101,13 @@ struct CandidatePairs
107101

108102
#ifdef __CUDA_ARCH__
109103
int id = ::atomicAdd(m_current_num_pairs.data(DEVICE), 1);
110-
if (id < m_pairs_handle.rows()) {
111-
if (m_hess.is_non_zero(c0, c1)) {
104+
if (id < m_pairs_handle.rows()) {
105+
add_candidate(id);
106+
if (!m_hess.is_non_zero(c0, c1)) {
112107
add_candidate(id);
113-
} else {
114-
::atomicAdd(m_current_num_index.data(DEVICE),
115-
m_variable_dim * m_variable_dim * 2);
116-
add_candidate_with_indices(id);
108+
int idd = ::atomicAdd(m_current_num_index.data(DEVICE),
109+
m_variable_dim * m_variable_dim * 2);
110+
add_candidate_with_indices(idd);
117111
}
118112
return true;
119113
} else {
@@ -123,12 +117,12 @@ struct CandidatePairs
123117
#else
124118
if (m_current_num_pairs(0) < m_pairs_handle.rows()) {
125119
int id = m_current_num_pairs(0);
126-
m_current_num_pairs(0)++;
127-
if (m_hess.is_non_zero(c0, c1)) {
128-
add_candidate(id);
129-
} else {
120+
m_current_num_pairs(0)++;
121+
add_candidate(id);
122+
if (!m_hess.is_non_zero(c0, c1)) {
123+
int idd = m_current_num_index(0);
130124
m_current_num_index(0) += m_variable_dim * m_variable_dim * 2;
131-
add_candidate_with_indices(id);
125+
add_candidate_with_indices(idd);
132126
}
133127
return true;
134128
} else {

include/rxmesh/diff/diff_query_kernel.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,8 @@ __global__ static void diff_kernel_active_pair(
399399

400400
std::pair<HandleT0, HandleT1> pair = pairs.get_pair(id);
401401

402+
assert(hess.is_non_zero(pair.first, pair.second));
403+
402404
PairIterator<HandleT0> iter(pair.first, pair.second);
403405

404406
ScalarT res = user_func(diff_handle, iter, objective);

include/rxmesh/diff/hessian_sparse_matrix.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ struct HessianSparseMatrix : public SparseMatrix<T>
4646
const IndexT local_i,
4747
const IndexT local_j) const
4848
{
49+
assert(is_non_zero(row_v, col_v));
50+
4951
const IndexT r_id =
5052
this->get_row_id(row_v) * this->m_replicate + local_i;
5153
const IndexT c_id =
@@ -63,6 +65,8 @@ struct HessianSparseMatrix : public SparseMatrix<T>
6365
const IndexT local_i,
6466
const IndexT local_j)
6567
{
68+
assert(is_non_zero(row_v, col_v));
69+
6670
const IndexT r_id =
6771
this->get_row_id(row_v) * this->m_replicate + local_i;
6872
const IndexT c_id =
@@ -81,6 +85,8 @@ struct HessianSparseMatrix : public SparseMatrix<T>
8185
const IndexT local_i,
8286
const IndexT local_j) const
8387
{
88+
assert(is_non_zero(row_v, col_v));
89+
8490
const IndexT r_id =
8591
this->get_row_id(row_v) * this->m_replicate + local_i;
8692
const IndexT c_id =

include/rxmesh/diff/lbfgs_solver.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,7 @@ struct LBFGSSolver
4949
y_list[i].reset(0, LOCATION_ALL);
5050
}
5151
}
52-
53-
inline void solve(cudaStream_t stream = NULL)
54-
{
55-
compute_direction(stream);
56-
line_search(stream);
57-
}
52+
5853

5954
inline void compute_direction(cudaStream_t stream = NULL)
6055
{

include/rxmesh/diff/newton_solver.h

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,7 @@ struct NetwtonSolver
5050
};
5151
}
5252
}
53-
54-
/**
55-
* @brief
56-
*/
57-
inline void solve(cudaStream_t stream = NULL)
58-
{
59-
compute_direction(stream);
60-
line_search(stream);
61-
}
53+
6254

6355
/**
6456
* @brief solve to get Newton direction
@@ -140,10 +132,13 @@ struct NetwtonSolver
140132
/**
141133
* @brief line search
142134
*/
135+
using Callback = std::function<void(Attribute<T, ObjHandleT>)>;
136+
143137
inline bool line_search(const T s_max = 1.0,
144138
const T shrink = 0.8,
145139
const int max_iters = 64,
146140
const T armijo_const = 1e-4,
141+
Callback cb = {},
147142
cudaStream_t stream = NULL)
148143
{
149144
// we are going to keep trying to update temp_objective until we reach
@@ -183,6 +178,9 @@ struct NetwtonSolver
183178

184179

185180
// eval new obj func
181+
if (cb) {
182+
cb(*temp_objective);
183+
}
186184
problem.eval_terms_passive(temp_objective.get(), stream);
187185

188186
// get the new value of the objective function

include/rxmesh/matrix/sparse_matrix.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ struct SparseMatrix
602602
* @brief check if the input entry (x,y) is a non-zero (i.e., if it
603603
* allocated in the CSR representation)
604604
*/
605-
__device__ __host__ bool is_non_zero(const IndexT x, IndexT y) const
605+
__device__ __host__ bool is_non_zero(const IndexT x, const IndexT y) const
606606
{
607607
const IndexT start = row_ptr()[x];
608608
const IndexT end = row_ptr()[x + 1];

0 commit comments

Comments
 (0)