Skip to content

Commit 209b711

Browse files
goodfeliTorax team
authored andcommitted
Fix bugs in Solver hash and eq.
Solver is used a static argument. Solver is polymorphic and the subclass identity is used as an implicit argument affecting the trace, so the subclass identity must be part of the key for the Jax cache. The Solver __hash__ and __eq__ methods are not hashing and comparing the class id. This CL fixes that by comparing a string representation of the class id, including the module path. PiperOrigin-RevId: 819912183
1 parent c54dfac commit 209b711

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

torax/_src/solver/solver.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434

3535
class Solver(abc.ABC):
36-
"""Solves for a single time steps update to State.
36+
"""Solves for a single time step's update to State.
3737
3838
Attributes:
3939
physics_models: Physics models.
@@ -45,11 +45,18 @@ def __init__(
4545
):
4646
self.physics_models = physics_models
4747

48+
def _class_id(self) -> str:
49+
return f'{self.__class__.__module__}.{self.__class__.__qualname__}'
50+
4851
def __hash__(self) -> int:
49-
return hash(self.physics_models)
52+
return hash((self._class_id(), self.physics_models))
5053

5154
def __eq__(self, other: typing_extensions.Self) -> bool:
52-
return self.physics_models == other.physics_models
55+
return (
56+
isinstance(other, Solver)
57+
and self._class_id() == other._class_id()
58+
and self.physics_models == other.physics_models
59+
)
5360

5461
@functools.partial(
5562
jax.jit,

0 commit comments

Comments
 (0)