@@ -167,6 +167,39 @@ def measure_ACC(mu, y, extract_label_indx=True): ## measures/calculates accuracy
167167 acc = jnp .sum ( jnp .equal (guess , lab ) )/ (y .shape [0 ] * 1. )
168168 return acc
169169
@partial(jit, static_argnums=[3])
def measure_BIC(X, n_model_params, max_model_score, is_log=True):
    """
    Computes the Bayesian information criterion (BIC) given the best
    likelihood-like score a model attained on a dataset.

    | BIC = -2 ln(L) + K * ln(N);
    | where N is number of data-points/rows of design matrix X,
    | K is total number parameters of the model of interest, and
    | L is the max/best-found value of a likelihood-like score L of the model

    Args:
        X: dataset/design matrix that a model was fit to (max-likelihood optimized)

        n_model_params: total number of model parameters (int)

        max_model_score: max likelihood-like score obtained by model on X

        is_log: is supplied `max_model_score` a log-likelihood? if this is False,
            this metric will apply a natural logarithm of the score (Default: True)

    Returns:
        scalar for the Bayesian information criterion score
    """
    num_rows = X.shape[0]  ## N: count of data-points in the design matrix
    ## ensure we work with a log-likelihood; `is_log` is a static (Python-level)
    ## argument under jit, so this branch is resolved at trace time
    log_like = max_model_score if is_log else jnp.log(max_model_score)
    ## BIC = K * ln(N) - 2 ln(L)
    return n_model_params * jnp.log(num_rows * 1.) - 2. * log_like
202+
170203@partial (jit , static_argnums = [2 ])
171204def measure_KLD (p_xHat , p_x , preserve_batch = False ):
172205 """
0 commit comments