Skip to content

Conversation

@eclipse1605
Copy link
Contributor

Closes #8049

@codecov
Copy link

codecov bot commented Jan 17, 2026

Codecov Report

❌ Patch coverage is 78.82353% with 18 lines in your changes missing coverage. Please review.
✅ Project coverage is 90.74%. Comparing base (c68c56e) to head (2fe8441).
⚠️ Report is 4 commits behind head on main.

Files with missing lines Patch % Lines
pymc/logprob/switch.py 78.82% 18 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #8058      +/-   ##
==========================================
- Coverage   91.42%   90.74%   -0.68%     
==========================================
  Files         117      121       +4     
  Lines       19154    19487     +333     
==========================================
+ Hits        17512    17684     +172     
- Misses       1642     1803     +161     
Files with missing lines Coverage Δ
pymc/logprob/switch.py 77.85% <78.82%> (-2.83%) ⬇️

... and 12 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm hoping we can simplify the code a bit further. Also let's get rid of all these typing.cast. It's in almost every line. Just mark the file as failing mypy if it requires so much work.

return (True, False)
return None

return _direct_form() or _swapped_form()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we simplify this code?

x = pm.Normal.dist(mu=0, sigma=1, size=(3,))
y = pt.switch(x > 0, x, scale * x)

if cond_variant == "x_gt_0":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not parametrize already with the objects you gonna need?


a = _extract_scale_from_measurable_mul(
cast(TensorVariable, neg_branch), cast(TensorVariable, x)
match = _match_scaled_switch_branches(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of having to guess, you could normalize the switch so that it's always switch(cond, x, neg_branch). You can write switch(c, t, f) -> switch(~c, f, t).

Maybe this allows you to simplify more logic elsewhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ya makes sense, though we still can’t use cond directly inside the logprob, because cond depends on the latent x

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the logprob you should be able to assume you already have a canonical form, because that's how you made it in the rewrite

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually i think we can do this and it should simplify the logprob. we can canonicalize in the rewrite so downstream logprob can assume one shape.

maybe something like this (pseudocode):

# in measurable rewrite:
# detect either ordering + equivalent conditions
#   switch(cond(x ? 0),  x,   a*x)
#   switch(cond(x ? 0), a*x,  x)
# and also swapped comparisons like 0 < x, 0 >= x, ...

(x, ax, sem) = match_branches_and_condition(node)

# normalize to always: switch(cond_canon, x, ax)
if true_branch_is(ax):
    sem = negate(sem)          # because switch(c, ax, x) == switch(~c, x, ax)
cond_canon = cond_from_semantics(x, sem) 

return measurable_switch_non_overlapping(cond_canon, x, ax)

# then in logprob:
# assume inputs are already canonical: switch(cond_canon, x, a*x)
sem = parse_canonical_cond(cond_canon, x)
gate = value_based_gate(value, sem) 
return switch(gate, logp(x, value), logp(a*x, value)) + check(a > 0)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, let's see if it works

@eclipse1605
Copy link
Contributor Author

I'm hoping we can simplify the code a bit further. Also let's get rid of all these typing.cast. It's in almost every line. Just mark the file as failing mypy if it requires so much work.

sure, is this a general convention we follow throughout the codebase?

@ricardoV94
Copy link
Member

I'm hoping we can simplify the code a bit further. Also let's get rid of all these typing.cast. It's in almost every line. Just mark the file as failing mypy if it requires so much work.

sure, is this a general convention we follow throughout the codebase?

No, but in general I favor not making our code a mess because of mypy

@eclipse1605
Copy link
Contributor Author

anyways i misunderstood mypy failing locally in the earlier commit, it wasn't required

@eclipse1605
Copy link
Contributor Author

@ricardoV94 the failing test is flaky right?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

ENH: extending equivalent non-overlapping switch logp forms

3 participants