diff --git a/src/stream_mapper/pytorch/builtin/_truncskewnorm.py b/src/stream_mapper/pytorch/builtin/_truncskewnorm.py index 46243a6..2bddca6 100644 --- a/src/stream_mapper/pytorch/builtin/_truncskewnorm.py +++ b/src/stream_mapper/pytorch/builtin/_truncskewnorm.py @@ -69,9 +69,10 @@ def ln_likelihood( cns, cens = self.coord_names, self.coord_err_names x = data[cns].array # (N, F) - _0 = self.xp.zeros_like(x)[None, ...] # (1, N, F) - a, b = _0 + self.xp.asarray([self.coord_bounds[k] for k in cns]).T[:, None, :] + a_, b_ = self._get_lower_upper_bound(data[self.indep_coord_names].array) + a, b = a_[idx], b_[idx] + # a, b = _0 + self.xp.asarray([self.coord_bounds[k] for k in cns]).T[:, None, :] mu = self._stack_param(mpars, "mu", cns)[idx] ln_s = self._stack_param(mpars, "ln-sigma", cns)[idx] skew = self._stack_param(mpars, "skew", cns)[idx] @@ -88,7 +89,7 @@ def ln_likelihood( # Find where -inf with xp.no_grad(): _lpdf = truncskewnorm_logpdf( - x[idx], loc=mu, ln_sigma=ln_s, skew=skew, a=a[idx], b=b[idx], xp=self.xp + x[idx], loc=mu, ln_sigma=ln_s, skew=skew, a=a, b=b, xp=self.xp ) fnt = xp.isfinite(_lpdf) # apply to X[idx] only @@ -99,18 +100,15 @@ def ln_likelihood( loc=mu[fnt], ln_sigma=ln_s[fnt], skew=skew[fnt], - a=a[idx][fnt], - b=b[idx][fnt], + a=a[fnt], + b=b[fnt], xp=self.xp, ) # Compute normal where SN is infinite. # Subtract 100 b/c that's where the SN logpdf drops to -inf n_lnpdf = ( - truncnorm_logpdf( - x[idx], loc=mu, ln_sigma=ln_s, a=a[idx], b=b[idx], xp=self.xp - ) - - 100 + truncnorm_logpdf(x[idx], loc=mu, ln_sigma=ln_s, a=a, b=b, xp=self.xp) - 100 ) idxlnliks = xp.where(fnt, sn_lnpdf, n_lnpdf)