fix derived scan logprob when observed provides more broadcastable information #8016
Conversation
|
I'm inclined to fix this on the Scan side in PyTensor. Basically if Furthermore Otherwise it's too tricky to work with scans programmatically, and we end up with failures/awkward work-arounds as demonstrated in this PR |
|
sorry for the slow reply, am i getting this right, implementation wise? we’d update PyTensor so that:
|
|
@ricardoV94 does this sound right? |
|
The inner should have as much information as the outer; the outer can have less. But yes. Let's open a PR and see how it goes |
|
so in we change and in

    if (
        type_input.dtype != type_output.dtype
        or type_input.broadcastable != type_output.broadcastable
    ):

we replace the broadcastable eq check with a compatibility check using

    if type_input.dtype != type_output.dtype:
        ...  # dtype error
    elif isinstance(type_input, TensorType) and isinstance(type_output, TensorType):
        if not type_input.is_super(type_output):
            ...  # compatibility error using is_super
    else:
        if type_input != type_output:
            ...  # fallback error |
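To make the difference between the two checks concrete, here is a minimal pure-Python sketch. It does not call PyTensor: static shapes are modeled as tuples of `int` or `None`, and the helper names `broadcastable` and `is_super` are stand-ins written for this illustration (the real `TensorType.is_super` also compares dtype and type class, which the sketch ignores). The assumption is that a type is "super" when every dimension it specifies is either unknown or equal to the other type's dimension.

```python
from typing import Optional, Tuple

# Hypothetical stand-in for a tensor's static shape: None means "unknown length".
StaticShape = Tuple[Optional[int], ...]


def broadcastable(shape: StaticShape) -> Tuple[bool, ...]:
    """Broadcast pattern derived from a static shape: True where length is 1."""
    return tuple(dim == 1 for dim in shape)


def is_super(general: StaticShape, specific: StaticShape) -> bool:
    """True if `general` accepts `specific`.

    Same number of dimensions, and each dimension of `general` is either
    unknown (None) or equal to the matching dimension of `specific`.
    """
    return len(general) == len(specific) and all(
        g is None or g == s for g, s in zip(general, specific)
    )


less_specific = (None, None)  # nothing known statically
more_specific = (None, 1)     # trailing axis known to have length 1

# Strict equality of broadcast patterns rejects the pair.
strict_ok = broadcastable(less_specific) == broadcastable(more_specific)

# The compatibility check accepts extra static information on one side only.
relaxed_ok = is_super(less_specific, more_specific)
reverse_ok = is_super(more_specific, less_specific)

print(strict_ok, relaxed_ok, reverse_ok)
```

The check is deliberately one-directional: the side with less static information can accept the side with more, but not the reverse, which matches the "inner should have as much information as the outer" rule stated earlier in the thread.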
|
I think you can use |
|
@ricardoV94 i have opened a PR (pymc-devs/pytensor#1861) |
|
@ricardoV94 should i close this PR? |
|
No, but it should be changed to just have a regression test for the original issue to make sure it's working once we fix it in PyTensor |
|
yeah makes sense |
Description
fixed a failure in derived scan logprob construction when the observed/value tensor provides more static broadcastability information than the generative scan graph (e.g. observed has a size-1 axis like (date, 1) while the scan state was inferred as non-broadcastable on that axis).
in this case, model.logp() could fail during the measurable scan rewrite with a scan outputs_info broadcast pattern mismatch (scan output inferred as matrix-like vs. outputs_info expecting vector-like).
Apply nodes (so scan reconstruction remains valid).
note
I think the same idea can be generalized by treating static broadcastability metadata as part of the measurable scan rewrite contract:
- outputs_info proxies for the logprob rewritten scan, ensure their TensorType.shape reflects any size-1/broadcastable axes implied by the outer variables.
- pt.join / outputs_info) so that init and scan outputs agree on broadcastability, without inserting broadcast Apply nodes into the inner graph (placeholders must remain nominal vars).
Related Issue
Checklist
Type of change