|
#pragma once

#define BUILD_CUDA_MODULE

#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>

#include <Eigen/Dense>
#include <gsl/gsl_randist.h>

#include <sogmm_open3d/SOGMMGPU.h>

// log(2 * pi), used in the Gaussian log-density normalization term.
#define LOG_2_M_PI 1.83788

namespace o3d = open3d;
| 17 | + |
| 18 | +namespace sogmm |
| 19 | +{ |
| 20 | + namespace gpu |
| 21 | + { |
| 22 | + template <typename T, uint32_t D> |
| 23 | + class EM |
| 24 | + { |
| 25 | + public: |
| 26 | + static constexpr uint32_t C = D * D; |
| 27 | + |
| 28 | + using Ptr = std::shared_ptr<EM<T, D>>; |
| 29 | + using ConstPtr = std::shared_ptr<const EM<T, D>>; |
| 30 | + |
| 31 | + using Container = SOGMM<T, D>; |
| 32 | + |
| 33 | + using Tensor = typename Container::Tensor; |
| 34 | + using SizeVector = typename Container::SizeVector; |
| 35 | + using Device = typename Container::Device; |
| 36 | + using Dtype = typename Container::Dtype; |
| 37 | + |
| 38 | + EM() |
| 39 | + { |
| 40 | + tol_ = 1e-3; |
| 41 | + reg_covar_ = 1e-6; |
| 42 | + max_iter_ = 100; |
| 43 | + |
| 44 | + device_ = Device("CUDA:0"); |
| 45 | + dtype_ = Dtype::template FromType<T>(); |
| 46 | + } |
| 47 | + |
| 48 | + ~EM() |
| 49 | + { |
| 50 | + } |
| 51 | + |
| 52 | + void initialize(const unsigned int &n_samples, const unsigned int &n_components) |
| 53 | + { |
| 54 | + unsigned int N = n_samples; |
| 55 | + unsigned int K = n_components; |
| 56 | + |
| 57 | + amax__ = Tensor::Zeros({N, 1, 1}, dtype_, device_); |
| 58 | + Log_Det_Cholesky__ = Tensor::Zeros({K, 1}, dtype_, device_); |
| 59 | + Log_Det_Cholesky_Tmp__ = Tensor::Zeros({K, D}, dtype_, device_); |
| 60 | + eDiff__ = Tensor::Zeros({N, K, D, 1}, dtype_, device_); |
| 61 | + Log_Prob_Norm__ = Tensor::Zeros({N, 1, 1}, dtype_, device_); |
| 62 | + |
| 63 | + Nk__ = Tensor::Zeros({1, K}, dtype_, device_); |
| 64 | + mDiff__ = Tensor::Zeros({K, D, N}, dtype_, device_); |
| 65 | + Log_Resp__ = Tensor::Zeros({N, K, 1}, dtype_, device_); |
| 66 | + } |
| 67 | + |
| 68 | + void logSumExp(const Tensor &in, const unsigned int dim, Tensor &out) |
| 69 | + { |
| 70 | + amax__ = in.Max({dim}, true); |
| 71 | + ((in - amax__).Exp()).Sum_({dim}, true, out); |
| 72 | + out.Log_(); |
| 73 | + out.Add_(amax__); |
| 74 | + } |
| 75 | + |
| 76 | + void estimateWeightedLogProb(const Tensor &Xt, const Container &sogmm, |
| 77 | + Tensor &Weighted_Log_Prob) |
| 78 | + { |
| 79 | + // Term 2 of Equation (3.14) |
| 80 | + // Log_Det_Cholesky__ = ((GetDiagonal(Precs_Chol_[0]).Log()).Sum({ 1 }, true)) |
| 81 | + // .Reshape({ 1, n_components_, 1 }); |
| 82 | + |
| 83 | + SizeVector Xt_shape = Xt.GetShape(); |
| 84 | + unsigned int N = Xt_shape[0]; |
| 85 | + unsigned int K = sogmm.n_components_; |
| 86 | + |
| 87 | + Log_Det_Cholesky__.template Fill<float>(0.0); |
| 88 | + Tensor Log_Det_Cholesky_View = (GetDiagonal(sogmm.precisions_cholesky_[0])); |
| 89 | + Log_Det_Cholesky_Tmp__.CopyFrom(Log_Det_Cholesky_View); |
| 90 | + Log_Det_Cholesky_Tmp__.Log_(); |
| 91 | + Log_Det_Cholesky_Tmp__.Sum_({1}, true, Log_Det_Cholesky__); |
| 92 | + |
| 93 | + // Diff, PDiff, Log_Gaussian_Prob are all terms for the first |
| 94 | + // term in Equation (3.14) |
| 95 | + eDiff__ = (Xt.Reshape({N, 1, D}) - sogmm.means_).Reshape({N, K, D, 1}); |
| 96 | + |
| 97 | + eDiff__ = (sogmm.precisions_cholesky_.Mul(eDiff__)).Sum({2}, true).Reshape({N, K, D, 1}); |
| 98 | + |
| 99 | + // Equation (3.14), output of estimateLogGaussianProb |
| 100 | + (eDiff__.Mul(eDiff__)).Sum_({2}, false, Weighted_Log_Prob); |
| 101 | + Weighted_Log_Prob.Add_(D * LOG_2_M_PI); |
| 102 | + Weighted_Log_Prob.Mul_(-0.5); |
| 103 | + Weighted_Log_Prob.Add_(Log_Det_Cholesky__.Reshape({1, K, 1})); |
| 104 | + |
| 105 | + // This is the first two terms in Equation (3.7) |
| 106 | + Weighted_Log_Prob.Add_(sogmm.weights_.Log()); |
| 107 | + } |
| 108 | + |
| 109 | + void eStep(const Tensor &Xt, const Container &sogmm) |
| 110 | + { |
| 111 | + unsigned int K = sogmm.n_components_; |
| 112 | + |
| 113 | + estimateWeightedLogProb(Xt, sogmm, Log_Resp__); |
| 114 | + |
| 115 | + Log_Prob_Norm__.template Fill<float>(0.0); |
| 116 | + logSumExp(Log_Resp__, 1, Log_Prob_Norm__); |
| 117 | + |
| 118 | + Log_Resp__.Sub_(Log_Prob_Norm__); |
| 119 | + |
| 120 | + likelihood_ = Log_Prob_Norm__.Mean({0, 1, 2}, false).template Item<T>(); |
| 121 | + } |
| 122 | + |
| 123 | + void mStep(const Tensor &Xt, const Tensor &Respt, Container &sogmm) |
| 124 | + { |
| 125 | + SizeVector Xt_shape = Xt.GetShape(); |
| 126 | + unsigned int N = Xt_shape[0]; |
| 127 | + unsigned int K = sogmm.n_components_; |
| 128 | + |
| 129 | + // initialize tensor for weights |
| 130 | + Nk__.template Fill<float>(0.0); |
| 131 | + Respt.Sum_({0}, true, Nk__); |
| 132 | + sogmm.weights_ = Nk__.T(); |
| 133 | + sogmm.weights_.Add_(Tensor::Ones({K, 1}, dtype_, device_) * 10 * |
| 134 | + std::numeric_limits<T>::epsilon()); |
| 135 | + |
| 136 | + // update means |
| 137 | + sogmm.means_ = (Respt.T().Matmul(Xt)).Div(sogmm.weights_).Reshape(sogmm.means_.GetShape()); |
| 138 | + |
| 139 | + // update covariances |
| 140 | + mDiff__ = (Xt.Reshape({N, 1, D}) - sogmm.means_).AsStrided({K, D, N}, {D, 1, D * K}); |
| 141 | + sogmm.covariances_ = mDiff__.MatmulBatched(mDiff__.Transpose(1, 2) * |
| 142 | + Respt.AsStrided({K, N, 1}, {1, K, 1})) |
| 143 | + .Reshape(sogmm.covariances_.GetShape()); |
| 144 | + sogmm.covariances_.Div_(sogmm.weights_.Reshape({1, K, 1, 1})); |
| 145 | + |
| 146 | + // add reg_covar_ along the diagonal for Covariances_ |
| 147 | + sogmm.covariances_[0].AsStrided({K, D}, {C, D + 1}).Add_(reg_covar_); |
| 148 | + |
| 149 | + // update weights |
| 150 | + sogmm.weights_.Div_(N); |
| 151 | + sogmm.weights_.Div_(sogmm.weights_.Sum({0}, false)); |
| 152 | + |
| 153 | + // update precision and covariance cholesky |
| 154 | + sogmm.updateCholesky(); |
| 155 | + } |
| 156 | + |
| 157 | + bool fit(const Tensor &Xt, const Tensor &Respt, Container &sogmm) |
| 158 | + { |
| 159 | + unsigned int K = sogmm.n_components_; |
| 160 | + |
| 161 | + SizeVector Xt_shape = Xt.GetShape(); |
| 162 | + unsigned int N = Xt_shape[0]; |
| 163 | + |
| 164 | + if (N <= 1) |
| 165 | + { |
| 166 | + throw std::runtime_error("fit: number of samples should be greater than 1."); |
| 167 | + } |
| 168 | + |
| 169 | + if (K <= 0) |
| 170 | + { |
| 171 | + throw std::runtime_error("fit: number of components should be greater than 0."); |
| 172 | + } |
| 173 | + |
| 174 | + if (N < K) |
| 175 | + { |
| 176 | + throw std::runtime_error("fit: number of components is " + |
| 177 | + std::to_string(K) + |
| 178 | + ". It should be strictly smaller than the " |
| 179 | + "number of points: " + |
| 180 | + std::to_string(N)); |
| 181 | + } |
| 182 | + |
| 183 | + initialize(N, K); |
| 184 | + |
| 185 | + mStep(Xt, Respt, sogmm); |
| 186 | + |
| 187 | + T lower_bound = -std::numeric_limits<T>::infinity(); |
| 188 | + for (unsigned int n_iter = 0; n_iter <= max_iter_; n_iter++) |
| 189 | + { |
| 190 | + T prev_lower_bound = lower_bound; |
| 191 | + |
| 192 | + // E step |
| 193 | + eStep(Xt, sogmm); |
| 194 | + |
| 195 | + // M step |
| 196 | + mStep(Xt, Log_Resp__.Exp().Reshape({N, K}), sogmm); |
| 197 | + |
| 198 | + // convergence check |
| 199 | + lower_bound = likelihood_; |
| 200 | + T change = lower_bound - prev_lower_bound; |
| 201 | + if (!std::isinf(change) && std::abs(change) < tol_) |
| 202 | + { |
| 203 | + converged_ = true; |
| 204 | + break; |
| 205 | + } |
| 206 | + } |
| 207 | + |
| 208 | + // o3d::core::MemoryManagerCached::ReleaseCache(device_); |
| 209 | + |
| 210 | + if (converged_) |
| 211 | + { |
| 212 | + return true; |
| 213 | + } |
| 214 | + else |
| 215 | + { |
| 216 | + return false; |
| 217 | + } |
| 218 | + } |
| 219 | + |
| 220 | + static void scoreSamples(const Tensor &Xt, const Container &sogmm, |
| 221 | + Tensor &Per_Sample_Log_Likelihood) |
| 222 | + { |
| 223 | + SizeVector Xt_shape = Xt.GetShape(); |
| 224 | + unsigned int N = Xt_shape[0]; |
| 225 | + unsigned int K = sogmm.n_components_; |
| 226 | + |
| 227 | + Tensor Weighted_Log_Prob = Tensor::Zeros({N, K, 1}, sogmm.dtype_, sogmm.device_); |
| 228 | + Tensor eDiff = Tensor::Zeros({N, K, D, 1}, sogmm.dtype_, sogmm.device_); |
| 229 | + Tensor Log_Det_Cholesky = Tensor::Zeros({K, 1}, sogmm.dtype_, sogmm.device_); |
| 230 | + Tensor Log_Det_Cholesky_Tmp = Tensor::Zeros({K, D}, sogmm.dtype_, sogmm.device_); |
| 231 | + |
| 232 | + Tensor Log_Det_Cholesky_View = (GetDiagonal(sogmm.precisions_cholesky_[0])); |
| 233 | + |
| 234 | + Log_Det_Cholesky.template Fill<float>(0.0); |
| 235 | + Log_Det_Cholesky_Tmp.CopyFrom(Log_Det_Cholesky_View); |
| 236 | + Log_Det_Cholesky_Tmp.Log_(); |
| 237 | + Log_Det_Cholesky_Tmp.Sum_({1}, true, Log_Det_Cholesky); |
| 238 | + |
| 239 | + eDiff = (Xt.Reshape({N, 1, D}) - sogmm.means_).Reshape({N, K, D, 1}); |
| 240 | + eDiff = (sogmm.precisions_cholesky_.Mul(eDiff)).Sum({2}, true).Reshape({N, K, D, 1}); |
| 241 | + |
| 242 | + (eDiff.Mul(eDiff)).Sum_({2}, false, Weighted_Log_Prob); |
| 243 | + Weighted_Log_Prob.Add_(D * LOG_2_M_PI); |
| 244 | + Weighted_Log_Prob.Mul_(-0.5); |
| 245 | + Weighted_Log_Prob.Add_(Log_Det_Cholesky.Reshape({1, K, 1})); |
| 246 | + |
| 247 | + Weighted_Log_Prob.Add_(sogmm.weights_.Log()); |
| 248 | + |
| 249 | + Tensor amax = Tensor::Zeros({N, 1, 1}, sogmm.dtype_, sogmm.device_); |
| 250 | + |
| 251 | + amax = Weighted_Log_Prob.Max({1}, true); |
| 252 | + ((Weighted_Log_Prob - amax).Exp()).Sum_({1}, true, Per_Sample_Log_Likelihood); |
| 253 | + Per_Sample_Log_Likelihood.Log_(); |
| 254 | + Per_Sample_Log_Likelihood.Add_(amax); |
| 255 | + } |
| 256 | + |
| 257 | + bool converged_ = false; |
| 258 | + |
| 259 | + Device device_; |
| 260 | + Dtype dtype_; |
| 261 | + |
| 262 | + T tol_; |
| 263 | + T reg_covar_; |
| 264 | + unsigned int max_iter_; |
| 265 | + T likelihood_; |
| 266 | + |
| 267 | + private: |
| 268 | + Tensor amax__; |
| 269 | + Tensor Log_Det_Cholesky__; |
| 270 | + Tensor Log_Det_Cholesky_Tmp__; |
| 271 | + Tensor eDiff__; |
| 272 | + Tensor Log_Prob_Norm__; |
| 273 | + |
| 274 | + Tensor Nk__; |
| 275 | + Tensor mDiff__; |
| 276 | + Tensor Log_Resp__; |
| 277 | + }; |
| 278 | + } |
| 279 | +} |
0 commit comments