logprob for non-overlapping switch accepts equivalent spellings #8058
Codecov Report❌ Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #8058      +/-   ##
==========================================
- Coverage   91.42%   90.74%   -0.68%
==========================================
  Files         117      121       +4
  Lines       19154    19487     +333
==========================================
+ Hits        17512    17684     +172
- Misses       1642     1803     +161
ricardoV94
left a comment
I'm hoping we can simplify the code a bit further. Also, let's get rid of all these typing.cast calls; they are in almost every line. Just mark the file as failing mypy if it requires so much work.
pymc/logprob/switch.py (Outdated)
return (True, False)
return None

return _direct_form() or _swapped_form()
Can we simplify this code?
tests/logprob/test_switch.py (Outdated)
x = pm.Normal.dist(mu=0, sigma=1, size=(3,))
y = pt.switch(x > 0, x, scale * x)

if cond_variant == "x_gt_0":
Why not parametrize directly with the objects you're going to need?
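One way to read this suggestion, as a minimal sketch: put the callables that build each variant into the case table itself, rather than string labels that the test body later decodes with `if`. The sketch uses plain Python floats and a hypothetical scalar `switch` helper in place of `pt.switch`; the case names and `check_all_spellings` are also made up for illustration. In the actual test suite the same list would presumably be handed to `pytest.mark.parametrize`.

```python
def switch(cond, if_true, if_false):
    """Scalar stand-in for pt.switch (hypothetical helper for this sketch)."""
    return if_true if cond else if_false


# Each case carries the callable it needs, not a label to decode later.
# The names on the left only serve as readable test ids.
CASES = [
    ("x_gt_0", lambda x, a: switch(x > 0, x, a * x)),
    ("0_lt_x", lambda x, a: switch(0 < x, x, a * x)),
    ("x_le_0_swapped", lambda x, a: switch(x <= 0, a * x, x)),
    ("0_ge_x_swapped", lambda x, a: switch(0 >= x, a * x, x)),
]


def check_all_spellings(scale, xs):
    """Return True if every spelling agrees with the first one on all inputs."""
    reference = CASES[0][1]
    return all(
        build(x, scale) == reference(x, scale)
        for _, build in CASES
        for x in xs
    )


print(check_all_spellings(0.5, [-2.0, -0.1, 0.0, 0.3, 4.0]))
```

With this layout the test body needs no branching on a variant name; adding a new spelling is a one-line change to the table.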
pymc/logprob/switch.py (Outdated)

a = _extract_scale_from_measurable_mul(
    cast(TensorVariable, neg_branch), cast(TensorVariable, x)
match = _match_scaled_switch_branches(
Instead of having to guess, you could normalize the switch so that it's always switch(cond, x, neg_branch). You can write switch(c, t, f) -> switch(~c, f, t).
Maybe this allows you to simplify more logic elsewhere?
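The identity behind this suggestion, switch(c, t, f) == switch(~c, f, t), can be sketched on a toy expression type. Everything here (the `Switch` tuple, the `NEGATE` table, `normalize`, `evaluate`) is hypothetical and only illustrates the canonicalization; none of it is PyMC or PyTensor API.

```python
from typing import NamedTuple

# Negation of each comparison operator: not (x > 0) is (x <= 0), and so on.
NEGATE = {"gt": "le", "le": "gt", "ge": "lt", "lt": "ge"}


class Switch(NamedTuple):
    op: str            # comparison of x against 0: "gt", "ge", "lt" or "le"
    true_branch: str   # "x" or "a*x"
    false_branch: str


def normalize(node: Switch) -> Switch:
    """Rewrite so that the unscaled branch "x" is always the true branch."""
    if node.true_branch == "x":
        return node
    # switch(c, t, f) == switch(~c, f, t)
    return Switch(NEGATE[node.op], node.false_branch, node.true_branch)


def evaluate(node: Switch, x: float, a: float) -> float:
    """Evaluate the toy switch numerically."""
    cond = {"gt": x > 0, "ge": x >= 0, "lt": x < 0, "le": x <= 0}[node.op]
    branches = {"x": x, "a*x": a * x}
    return branches[node.true_branch] if cond else branches[node.false_branch]


swapped = Switch("le", "a*x", "x")
canonical = normalize(swapped)
print(canonical)
```

After normalization, downstream code only has to handle one branch ordering; the four comparison operators remain, but the question of which branch is scaled is settled.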
ya makes sense, though we still can’t use cond directly inside the logprob, because cond depends on the latent x
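To make this point concrete: for y = switch(x > 0, x, a*x) with a > 0, y has the same sign as x, so the branch can be selected from the observed value rather than from the latent x. A minimal numeric sketch, assuming a standard normal x and using only the standard library (the function names are illustrative, not PyMC's):

```python
import math


def normal_logpdf(v: float) -> float:
    """Log-density of a standard normal at v."""
    return -0.5 * v * v - 0.5 * math.log(2.0 * math.pi)


def switch_logp(value: float, a: float) -> float:
    """Log-density of y = x if x > 0 else a * x, for x ~ Normal(0, 1).

    Requires a > 0 so that sign(y) == sign(x) and the two branches
    map to non-overlapping regions of y.
    """
    if a <= 0:
        raise ValueError("scale must be positive")
    if value > 0:
        # y = x on this side: density is unchanged.
        return normal_logpdf(value)
    # y = a * x on this side: change of variables x = y / a.
    return normal_logpdf(value / a) - math.log(a)


def total_mass(a: float, lo: float = -10.0, hi: float = 10.0, n: int = 20000) -> float:
    """Midpoint-rule integral of exp(switch_logp) over [lo, hi]."""
    h = (hi - lo) / n
    return sum(math.exp(switch_logp(lo + (i + 0.5) * h, a)) for i in range(n)) * h


print(round(total_mass(0.5), 4))
```

The integral serves as a sanity check: if the gate or the Jacobian term `-log(a)` were wrong, the density would not integrate to approximately one.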
in the logprob you should be able to assume you already have a canonical form, because that's how you made it in the rewrite
actually i think we can do this and it should simplify the logprob. we can canonicalize in the rewrite so downstream logprob can assume one shape.
maybe something like this (pseudocode):
# in measurable rewrite:
# detect either ordering + equivalent conditions
#   switch(cond(x ? 0), x, a*x)
#   switch(cond(x ? 0), a*x, x)
# and also swapped comparisons like 0 < x, 0 >= x, ...
(x, ax, sem) = match_branches_and_condition(node)

# normalize to always: switch(cond_canon, x, ax)
if true_branch_is(ax):
    sem = negate(sem)  # because switch(c, ax, x) == switch(~c, x, ax)
cond_canon = cond_from_semantics(x, sem)
return measurable_switch_non_overlapping(cond_canon, x, ax)

# then in logprob:
# assume inputs are already canonical: switch(cond_canon, x, a*x)
sem = parse_canonical_cond(cond_canon, x)
gate = value_based_gate(value, sem)
return switch(gate, logp(x, value), logp(a*x, value)) + check(a > 0)
yeah, let's see if it works
sure, is this a general convention we follow throughout the codebase?
No, but in general I favor not making our code a mess because of mypy
anyways i misunderstood mypy failing locally in the earlier commit, it wasn't required
@ricardoV94 the failing test is flaky right?
Closes #8049