Skip to content
This repository was archived by the owner on Dec 18, 2023. It is now read-only.

Commit 26529f8

Browse files
himaghnafacebook-github-bot
authored andcommitted
Only store Sigma values in BART samples instead of object (#1527)
Summary: Pull Request resolved: #1527 Background: We are building Bayesian Additive Regression Trees (BART) as an experimental causal inference model in beanmachine. Details of the project can be found in https://docs.google.com/document/d/11nkB6UTGpvQBEC2yBjfgwAr8VabTlD7R9XufGQG0EvI/edit?usp=sharing and the proposed design can be found in the draft design document: https://docs.google.com/document/d/1o3J7yobDF0M9E27Y0tP2889fycmemXUZbHE5cebRqzs/edit?usp=sharing. In this diff: The noise standard deviation (sigma) parameter is never really used in the prediction tasks. While we would like to retain them for diagnostic purposes, there is no reason to store the NoiseStandardDeviation object in the sample trace. In this diff, we are modifying the BART class to only store float samples of the noise standard deviation. Reviewed By: feynmanliang Differential Revision: D37635208 fbshipit-source-id: be2f53b61b666fe9d50d2504a57351c91bd24915
1 parent 9ba385f commit 26529f8

File tree

1 file changed

+1
-1
lines changed
  • src/beanmachine/ppl/experimental/causal_inference/models/bart

1 file changed

+1
-1
lines changed

src/beanmachine/ppl/experimental/causal_inference/models/bart/bart_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def _step(self) -> Tuple[List, float]:
219219
self.X
220220
)
221221
self._update_sigma(self.y - self._predict_step())
222-
return self.all_trees, self.sigma
222+
return self.all_trees, self.sigma.val
223223

224224
def _update_leaf_mean(self, tree: Tree, partial_residual: torch.Tensor):
225225
"""

0 commit comments

Comments
 (0)