Skip to content

Commit 142a7c7

Browse files
committed
Fix type annotatoions
1 parent f01018b commit 142a7c7

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

gmmx/gmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def log_prob(self, x: jax.Array, means: jax.Array) -> jax.Array:
426426
)
427427

428428
@classmethod
429-
def from_precisions(cls, precisions: jax.AnyArray) -> DiagCovariances:
429+
def from_precisions(cls, precisions: AnyArray) -> DiagCovariances:
430430
"""Create covariance matrix from precision matrices"""
431431
values = 1.0 / precisions
432432
return cls.from_squeezed(values=values)
@@ -964,7 +964,7 @@ def _initialize_gmm(self, x: AnyArray) -> None:
964964
self._gmm = GaussianMixtureModelJax.from_squeezed(
965965
means=self.means_init, # type: ignore [arg-type]
966966
covariances=covar.from_precisions(
967-
self.precisions_init.astype(np.float32)
967+
self.precisions_init.astype(np.float32) # type: ignore [union-attr]
968968
).values_numpy,
969969
weights=self.weights_init, # type: ignore [arg-type]
970970
covariance_type=self.covariance_type,

0 commit comments

Comments
 (0)