Conversation

@eclipse1605
Contributor

Description

Fixed a failure in derived scan logprob construction when the observed/value tensor provides more static broadcastability information than the generative scan graph (e.g., the observed variable has a size-1 axis, like (date, 1), while the scan state was inferred as non-broadcastable on that axis).

In this situation, model.logp() could fail during the measurable scan rewrite with a scan outputs_info broadcast-pattern mismatch (the scan output inferred as matrix-like vs. outputs_info expecting vector-like).
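
For concreteness, here is what "more static broadcastability information" means at the type level (a small sketch; the variable names are illustrative, not from the PR):

import pytensor.tensor as pt

value = pt.tensor("value", shape=(None, 1))     # observed data: (date, 1)
state = pt.tensor("state", shape=(None, None))  # scan-inferred state: no size-1 axis

print(value.broadcastable)  # (False, True): the trailing axis is statically broadcastable
print(state.broadcastable)  # (False, False): the generative scan graph lacks that information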

  • The scan logprob rewrite now propagates and aligns the broadcastability metadata introduced by observed/value variables, so the rewritten scan is internally consistent.
  • This happens without turning scan inner placeholders into non-nominal Apply nodes, so scan reconstruction remains valid.

Note

I think the same idea can be generalized by treating static broadcastability metadata as part of the measurable scan rewrite contract:

  • Identify the “outer” value/observed variables that participate in the measurable rewrite and extract their broadcastable axes (ignoring the time axis for sequences).
  • When creating inner scan placeholders / outputs_info proxies for the logprob-rewritten scan, ensure their TensorType.shape reflects any size-1/broadcastable axes implied by the outer variables.
  • Apply the same normalization to the tapped init buffers (pt.join / outputs_info) so that the init and the scan outputs agree on broadcastability, without inserting broadcast Apply nodes into the inner graph (placeholders must remain nominal variables); see the sketch after this list.
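
A rough sketch of that normalization (align_broadcastable is a hypothetical helper for illustration, not existing PyMC/PyTensor API; it assumes the time axis has already been dropped from the outer type):

from pytensor.tensor.type import TensorType

def align_broadcastable(inner_type: TensorType, outer_type: TensorType) -> TensorType:
    # Hypothetical helper: copy any size-1 (broadcastable) axes of the
    # outer value/observed type onto the inner placeholder's type.
    new_shape = tuple(
        1 if outer_dim == 1 else inner_dim
        for inner_dim, outer_dim in zip(inner_type.shape, outer_type.shape)
    )
    return inner_type.clone(shape=new_shape)

Since the result is just a new TensorType, the inner placeholders can be rebuilt with it directly, keeping them nominal variables rather than outputs of broadcast Apply nodes.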

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ricardoV94
Member

ricardoV94 commented Jan 9, 2026

I'm inclined to fix this on the Scan side in PyTensor. Basically, if inp.type.filter_variable(out.type) passes, we should be good to go. This would mean that shape=(None,) -> (5,) or (None,) -> (1,) is fair game, just not (5,) -> (2,) or anything that changes the number of dims.

Furthermore, shape=(5,) -> (None,) would end up with a specify_shape in the output and therefore shape=(5,) -> (5,). This is the same rule we use for rewrites.

Otherwise it's too tricky to work with scans programmatically, and we end up with failures/awkward workarounds, as demonstrated in this PR.
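
In code terms, the proposed rule would look roughly like this (an illustration of the intended semantics, not the actual Scan call site):

import pytensor.tensor as pt

inner = pt.tensor("inner", shape=(None,))  # inner graph: unknown static length
outer = pt.tensor("outer", shape=(1,))     # outer input: statically broadcastable

# (None,) -> (1,) is fair game: filter_variable accepts the more specific variable.
ok = inner.type.filter_variable(outer)

# Changing the number of dims is not:
try:
    inner.type.filter_variable(pt.tensor("bad", shape=(None, None)))
except TypeError:
    pass  # rejected, as expected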

@eclipse1605
Contributor Author

Sorry for the slow reply. Am I getting this right, implementation-wise? We'd update PyTensor so that:

  • in check_broadcast, we only raise if the inner expects broadcastable axes that the outer doesn't have, but we allow the outer to be more broadcastable/specific than the inner;
  • in Scan.validate_inner_graph, we consider input/output types compatible if the input type is a supertype (same dtype/ndim; the output can be more specific), rather than requiring exact broadcastable equality.

@eclipse1605
Contributor Author

@ricardoV94 does this sound right?

@ricardoV94
Member

The inner should have as much information as the outer; the outer can have less. But yes. Let's open a PR and see how it goes.

@eclipse1605
Contributor Author

So, in pytensor/pytensor/scan/op.py:

in check_broadcast, we change if b1 != b2 to if b1 and not b2,

and in validate_inner_graph, instead of

if (
    type_input.dtype != type_output.dtype
    or type_input.broadcastable != type_output.broadcastable
):

we replace the broadcastable equality check with a compatibility check using TensorType.is_super():

if type_input.dtype != type_output.dtype:
    raise TypeError(...)  # dtype mismatch
elif isinstance(type_input, TensorType) and isinstance(type_output, TensorType):
    if not type_input.is_super(type_output):
        raise TypeError(...)  # the input type must be a supertype of the (more specific) output type
else:
    if type_input != type_output:
        raise TypeError(...)  # fallback: exact equality for non-tensor types
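
For reference, the is_super direction this relies on (a quick sanity check of the semantics, not library documentation):

from pytensor.tensor.type import TensorType

general = TensorType("float64", shape=(None,))
specific = TensorType("float64", shape=(1,))

assert general.is_super(specific)      # a less specific input type accepts a more specific output
assert not specific.is_super(general)  # the reverse is not allowed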

@ricardoV94
Member

I think you can always use is_super, but let's try.

@eclipse1605
Contributor Author

@ricardoV94 I have opened a PR (pymc-devs/pytensor#1861).

@eclipse1605
Contributor Author

@ricardoV94 Should I close this PR?

@ricardoV94
Member

No, but it should be changed to just have a regression test for the original issue, to make sure it's working once we fix it in PyTensor.

@eclipse1605
Contributor Author

Yeah, makes sense.



Development

Successfully merging this pull request may close these issues.

Derived scan logprob fails when observed data provides more broadcastable information than generative graph
