Skip to content

fix derived tag check rules#192

Merged
patrick-kidger merged 3 commits intopatrick-kidger:mainfrom
jpbrodrick89:jpb/fixtagchecks
Jan 31, 2026
Merged

fix derived tag check rules#192
patrick-kidger merged 3 commits intopatrick-kidger:mainfrom
jpbrodrick89:jpb/fixtagchecks

Conversation

@jpbrodrick89
Copy link
Contributor

Just noticed that a few of these were incorrect for composed/derived linear operator, think all fixed as best we can for now.

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, thank you! Don't know how these slipped in 😅 the fixes all look correct to me. I have some minor comments but that's it.

return check(operator.operator)


def _scalar_sign(scalar) -> int | None:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An enum.Enum might be a bit neater here?



def _scalar_sign(scalar) -> int | None:
"""Returns scalar sign if known at trace time otherwise None."""
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means that the behaviour will change depending on whether or not we are inside JIT, which I think is undesirable. I think it'd probably be better to just always treat JAX arrays as unknown? The JAX-array-outside-JIT case is fairly edge-case anyway since so little computation ever happens outside JIT.

So I think we'd end up replacing the try-except with a isinstance(scalar, (int, float, np.ndarray, np.generic)).

WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would work too, but behaviour may still differ in or outside of jit if jax.jit is used instead eqx.filter_jit is used which could convert all these type to tracers? If you'd rather have it always unknown I can sympathise with that too.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay without the jit/filter_jit equivalency here, since that's already true across a whole host of cases – and making that possible is really the point of filter_jit in the first place.

(Coming at this a few years later, I really do not like that we had to introduce filter_* though. Sometimes I wonder if we should have mandated trees with all-arrays and tackled this point with static fields instead. That ship has very definitely sailed, however.)

@jpbrodrick89
Copy link
Contributor Author

Another idea for has_unit_diagonal is you could undo the scalar multiplication in a clever way in Triangular so that has_unit_diagonal simply means that the BASE operator has_unit_diagonal so once the scalar is removed standard solve optimisations can continue? Maybe a bit confusing, but potentially worth it?

@patrick-kidger
Copy link
Owner

Another idea for has_unit_diagonal is you could undo the scalar multiplication in a clever way in Triangular so that has_unit_diagonal simply means that the BASE operator has_unit_diagonal so once the scalar is removed standard solve optimisations can continue? Maybe a bit confusing, but potentially worth it?

Haha! This is a good observation. I think the correct separation of concerns here would be for has_unit_diagonal should stay the same, and then for Triangular to special-case whichever derived operators, call has_unit_diagonal on their wrapped operators, and then do whatever logic happens on top of that.

It's fairly edge-case so that comes under "happy to take a PR on that" if you feel strongly!

@patrick-kidger patrick-kidger merged commit a61917e into patrick-kidger:main Jan 31, 2026
1 check failed
@patrick-kidger
Copy link
Owner

patrick-kidger commented Jan 31, 2026

LGTM, merged! Thank you for the fix. 🎉

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants