Skip to content

Commit 16f3781

Browse files
mmccrackanMichael McCrackan
andauthored
Add fitting functionality using Ceres-Solver (#189)
Co-authored-by: Michael McCrackan <mmccrack@login33.chn.perlmutter.nersc.gov>
1 parent e9923bb commit 16f3781

11 files changed

Lines changed: 872 additions & 2 deletions

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ find_package(Spt3g REQUIRED)
3131
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
3232
find_package(FLAC)
3333
find_package(GSL)
34+
find_package(Ceres)
3435

3536
find_package(OpenMP)
3637
if(OPENMP_FOUND)
@@ -68,6 +69,7 @@ add_library(so3g SHARED
6869
src/so_linterp.cxx
6970
src/exceptions.cxx
7071
src/array_ops.cxx
72+
src/fitting_ops.cxx
7173
)
7274

7375
# We could disable the lib prefix on the output library... but let's not.
@@ -83,6 +85,8 @@ target_link_libraries(so3g PUBLIC spt3g::core)
8385
# Link to GSL
8486
target_include_directories(so3g PRIVATE ${GSL_INCLUDE_DIR})
8587
target_link_libraries(so3g PUBLIC ${GSL_LIBRARIES})
88+
# Link Ceres
89+
target_link_libraries(so3g PUBLIC Ceres::ceres Eigen3::Eigen)
8690

8791
# FLAC- library already comes from spt3g dependencies, but
8892
# we need to have the headers.

Dockerfile

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,31 @@ RUN apt update && apt install -y \
1414
gfortran \
1515
libopenblas-openmp-dev \
1616
libbz2-dev \
17-
python-is-python3
17+
python-is-python3 \
18+
libgoogle-glog-dev \
19+
libgflags-dev \
20+
libmetis-dev \
21+
libgtest-dev \
22+
libabsl-dev \
23+
libeigen3-dev
1824

1925
# Set the working directory
2026
WORKDIR /app_lib/so3g
2127

28+
# Fetch and install ceres-solver
29+
RUN git clone --depth 1 --branch 2.2.0 --recurse-submodules https://github.com/ceres-solver/ceres-solver
30+
31+
WORKDIR /app_lib/so3g/ceres-solver
32+
33+
RUN mkdir build \
34+
&& cd build \
35+
&& cmake .. -DBUILD_TESTING=OFF \
36+
&& make -j$(nproc) \
37+
&& make install
38+
39+
# Set the working directory back to so3g
40+
WORKDIR /app_lib/so3g
41+
2242
# Copy the current directory contents into the container
2343
ADD . /app_lib/so3g
2444

cmake/FindCeres.cmake

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# - Find the Ceres Solver library
2+
#
3+
# CERES_FOUND
4+
# CERES_INCLUDE_DIRS
5+
# CERES_LIBRARIES
6+
# CERES_LIBRARY_DIRS
7+
8+
# Look for the Ceres package
9+
find_path(CERES_INCLUDE_DIR NAMES ceres/ceres.h HINTS ENV CERES_DIR PATH_SUFFIXES include)
10+
find_library(CERES_LIBRARY NAMES ceres HINTS ENV CERES_DIR PATH_SUFFIXES lib)
11+
12+
# Get Dependencies
13+
find_package(Eigen3 REQUIRED)
14+
find_package(Glog REQUIRED)
15+
find_package(Gflags REQUIRED)
16+
17+
# Create the imported Ceres target
18+
add_library(Ceres::Ceres UNKNOWN IMPORTED)
19+
set_target_properties(Ceres::Ceres PROPERTIES
20+
IMPORTED_LOCATION "${CERES_LIBRARY}"
21+
INTERFACE_INCLUDE_DIRECTORIES "${CERES_INCLUDE_DIR};${EIGEN3_INCLUDE_DIR}"
22+
INTERFACE_LINK_LIBRARIES "${GLOG_LIBRARIES};${GFLAGS_LIBRARIES};${CERES_LIBRARY}"
23+
)
24+
25+
include (FindPackageHandleStandardArgs)
26+
find_package_handle_standard_args (Ceres DEFAULT_MSG CERES_LIBRARY CERES_INCLUDE_DIR)
27+
28+
# Set the results so they can be used by the project
29+
mark_as_advanced(CERES_INCLUDE_DIR CERES_LIBRARY)

cmake/FindEigen3.cmake

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# - Find the Eigen library
2+
#
3+
# EIGEN_FOUND
4+
# EIGEN_INCLUDE_DIRS
5+
# EIGEN_LIBRARIES
6+
# EIGEN_LIBRARY_DIRS
7+
8+
if (EIGEN_INCLUDE_DIR)
9+
# Already in cache, be silent
10+
set (EIGEN_FIND_QUIETLY TRUE)
11+
endif (EIGEN_INCLUDE_DIR)
12+
13+
find_path(EIGEN_INCLUDE_DIR "Eigen/Core"
14+
HINTS ENV EIGEN_DIR
15+
PATH_SUFFIXES eigen3
16+
)
17+
18+
add_library(Eigen3::Eigen INTERFACE IMPORTED)
19+
set_target_properties(Eigen3::Eigen PROPERTIES
20+
INTERFACE_INCLUDE_DIRECTORIES "${EIGEN_INCLUDE_DIR}")
21+
22+
include (FindPackageHandleStandardArgs)
23+
find_package_handle_standard_args (Eigen3 DEFAULT_MSG EIGEN_INCLUDE_DIR)
24+
25+
mark_as_advanced(EIGEN_INCLUDE_DIR)

cmake/FindGFlags.cmake

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# - Find GFLAGS
2+
#
3+
# GFLAGS_FOUND
4+
# GFLAGS_INCLUDE_DIRS
5+
# GFLAGS_LIBRARIES
6+
# GFLAGS_LIBRARY_DIRS
7+
8+
if (GFLAGS_INCLUDE_DIR)
9+
# Already in cache, be silent
10+
set (GFLAG_FIND_QUIETLY TRUE)
11+
endif (GFLAGS_INCLUDE_DIR)
12+
13+
find_path(GFLAGS_INCLUDE_DIR NAME gflags/gflags.h HINTS ENV GFLAGS_DIR PATH_SUFFIXES include)
14+
find_library(GFLAGS_LIBRARY NAME gflags HINTS ENV GFLAGS_DIR PATH_SUFFIXES lib)
15+
16+
set(GFLAGS_INCLUDE_DIRS ${GFLAGS_INCLUDE_DIR})
17+
set(GFLAGS_LIBRARIES ${GFLAGS_LIBRARY})
18+
19+
include (FindPackageHandleStandardArgs)
20+
find_package_handle_standard_args (Gflags DEFAULT_MSG GFLAGS_LIBRARIES GFLAGS_INCLUDE_DIRS)
21+
22+
mark_as_advanced(GFLAGS_LIBRARY_DEBUG GFLAGS_LIBRARY_RELEASE
23+
GFLAGS_LIBRARY GFLAGS_INCLUDE_DIR GFLAGS_ROOT_DIR)

cmake/FindGlog.cmake

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# - Find glog
2+
#
3+
# GLOG_FOUND
4+
# GLOG_INCLUDE_DIRS
5+
# GLOG_LIBRARIES
6+
# GLOG_LIBRARY_DIRS
7+
8+
if (GLOG_INCLUDE_DIR)
9+
# Already in cache, be silent
10+
set (GLOG_FIND_QUIETLY TRUE)
11+
endif (GLOG_INCLUDE_DIR)
12+
13+
find_path(GLOG_INCLUDE_DIR glog/logging.h HINTS ENV GLOG_DIR PATH_SUFFIXES include)
14+
find_library(GLOG_LIBRARY NAME glog HINTS ENV GLOG_DIR PATH_SUFFIXES lib)
15+
16+
set(GLOG_INCLUDE_DIRS ${GLOG_INCLUDE_DIR})
17+
set(GLOG_LIBRARIES ${GLOG_LIBRARY})
18+
19+
include (FindPackageHandleStandardArgs)
20+
find_package_handle_standard_args (Glog DEFAULT_MSG GLOG_LIBRARY GLOG_INCLUDE_DIR)
21+
22+
mark_as_advanced(GLOG_LIBRARY_DEBUG GLOG_LIBRARY_RELEASE
23+
GLOG_LIBRARY GLOG_INCLUDE_DIR GLOG_ROOT_DIR)

include/array_ops.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#pragma once
2+
3+
int get_dtype(const bp::object &);
4+
5+
template <typename T>
6+
T _calculate_median(const T*, const int);

include/fitting_ops.h

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
#pragma once
2+
3+
#include <ceres/ceres.h>
4+
5+
template <int Degree>
6+
struct PolynomialModel
7+
{
8+
// Ceres requires number of params
9+
// to be known at compile time
10+
static constexpr int nparams = Degree + 1;
11+
12+
template <typename T>
13+
static T eval(T x, const T* params)
14+
{
15+
const T p0 = params[0];
16+
T result = p0;
17+
for (int i = 1; i < nparams; ++i) {
18+
const T p = params[i];
19+
result += p * ceres::pow(x, T(i));
20+
}
21+
22+
return result;
23+
}
24+
// Not needed for least squares as ceres
25+
// supports boundaries
26+
template <typename T>
27+
static bool check_bounds(const T* params)
28+
{
29+
return true;
30+
}
31+
};
32+
33+
struct NoiseModel
34+
{
35+
// Ceres requires number of params
36+
// to be known at compile time
37+
static constexpr int nparams = 3;
38+
39+
template <typename T>
40+
static T eval(T f, const T* params)
41+
{
42+
const T fknee = params[0];
43+
const T w = params[1];
44+
const T alpha = params[2];
45+
46+
return w * (1.0 + ceres::pow(fknee / f, alpha));
47+
}
48+
49+
// Slightly hacky way of bounds checking but is
50+
// suggested by Ceres to ensure it never goes
51+
// out of bounds
52+
template <typename T>
53+
static bool check_bounds(const T* params)
54+
{
55+
const T w = params[1];
56+
if (w <= 0.0) {
57+
return false;
58+
}
59+
return true;
60+
}
61+
};
62+
63+
// Model independent cost function for least-squares fitting
64+
template <typename Model>
65+
struct CostFunction
66+
{
67+
using model = Model;
68+
69+
CostFunction(int n, const double* x_data, const double* y_data)
70+
: n_pts(n), x(x_data), y(y_data) {}
71+
72+
template <typename T>
73+
bool operator()(const T* const params, T* residual) const {
74+
for (int i = 0; i < n_pts; ++i) {
75+
T model = Model::eval(T(x[i]), params);
76+
residual[i] = T(y[i]) - model;
77+
}
78+
return true;
79+
}
80+
81+
static ceres::Problem create(const int n, const double* xx,
82+
const double* yy, double* p)
83+
{
84+
ceres::Problem problem;
85+
86+
problem.AddResidualBlock(
87+
new ceres::AutoDiffCostFunction<CostFunction<Model>,
88+
ceres::DYNAMIC, Model::nparams>(
89+
new CostFunction<Model>(n, xx, yy), n), nullptr, p);
90+
91+
return problem;
92+
}
93+
94+
private:
95+
const int n_pts;
96+
const double* x;
97+
const double* y;
98+
};
99+
100+
// Model independent Negative Log Likelihood for generalized
101+
// unconstrained minimization. This is to be used when data
102+
// has residuals that follow a chi^2(1) distribution.
103+
template <typename Model>
104+
struct NegLogLikelihood
105+
{
106+
using model = Model;
107+
108+
NegLogLikelihood(int n, const double* x_data, const double* y_data)
109+
: n_pts(n), x(x_data), y(y_data) {}
110+
111+
template <typename T>
112+
bool operator()(const T* const params, T* cost) const
113+
{
114+
// Check bounds (saves a lot of time)
115+
if (!model::check_bounds(params)) {
116+
return false;
117+
}
118+
119+
cost[0] = T(0.);
120+
for (int i = 0; i < n_pts; ++i) {
121+
T model = Model::eval(T(x[i]), params);
122+
cost[0] += ceres::log(model) + T(y[i]) / model;
123+
}
124+
125+
return true;
126+
}
127+
128+
static ceres::FirstOrderFunction* create(int n, const double* xx,
129+
const double* yy)
130+
{
131+
// Ceres takes ownership of pointers so no cleanup is required
132+
return new ceres::AutoDiffFirstOrderFunction<NegLogLikelihood<Model>,
133+
Model::nparams>(new NegLogLikelihood<Model>(n, xx, yy));
134+
}
135+
136+
private:
137+
const int n_pts;
138+
const double* x;
139+
const double* y;
140+
};

src/array_ops.cxx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ extern "C" {
2727
#include "so3g_numpy.h"
2828
#include "numpy_assist.h"
2929
#include "Ranges.h"
30+
#include "array_ops.h"
3031

3132
// TODO: Generalize to double precision too.
3233
// This implements Jon's noise model for ACT. It takes in
@@ -1118,6 +1119,9 @@ T _calculate_median(const T* data, const int n)
11181119
return gsl_stats_median(data_copy.data(), 1, n);
11191120
}
11201121

1122+
template double _calculate_median<double>(const double* arr, int size);
1123+
template float _calculate_median<float>(const float* arr, int size);
1124+
11211125
template <typename T>
11221126
void _detrend(T* data, const int ndets, const int nsamps, const int row_stride,
11231127
const std::string & method, const int linear_ncount,
@@ -1256,7 +1260,6 @@ void detrend(bp::object & tod, const std::string & method, const int linear_ncou
12561260
}
12571261
}
12581262

1259-
12601263
PYBINDINGS("so3g")
12611264
{
12621265
bp::def("nmat_detvecs_apply", nmat_detvecs_apply);

0 commit comments

Comments
 (0)