Under the current convention, multiple files across the `examples/` directory define functions as:
```python
def logistic_loss(params, feature_matrix, labels):
```
With the corresponding type hints added, the same function signature would become:
```python
from typing import Any, Mapping

import jax

def logistic_loss(
    params: Mapping[str, Any],
    feature_matrix: Any,  # PyTree of arrays
    labels: jax.Array,
) -> jax.Array:
    ...  # body unchanged
```
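This variant is helpful because the type hints let pytype flag mismatched arguments statically and make explicit which inputs are JAX arrays that get traced, removing ambiguity about what the function expects under transformations such as `jax.jit`. It also keeps the example simple and aligned with the annotation style of JAX core library components.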
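A minimal runnable sketch of the annotated function follows. It assumes a linear model whose `params` mapping holds a weight vector under the hypothetical key `"w"` and a scalar bias under `"b"`, a feature matrix that is a single 2-D array (the simplest PyTree), and 0/1 float labels. The loss body is illustrative, not the one in `examples/`. It shows that the annotations are inert at runtime: `jax.jit` and `jax.value_and_grad` trace the function exactly as they would the unannotated version.

```python
from typing import Any, Mapping

import jax
import jax.numpy as jnp


def logistic_loss(
    params: Mapping[str, Any],
    feature_matrix: Any,  # PyTree of arrays; here assumed to be one 2-D array
    labels: jax.Array,
) -> jax.Array:
    """Mean binary logistic loss, assuming 0/1 float labels."""
    # Hypothetical parameter layout: weight vector "w" and scalar bias "b".
    logits = feature_matrix @ params["w"] + params["b"]
    # softplus(z) - y * z is the numerically stable form of
    # -y * log(sigmoid(z)) - (1 - y) * log(1 - sigmoid(z)).
    return jnp.mean(jax.nn.softplus(logits) - labels * logits)


# Annotations do not affect tracing: jit and grad work as before.
loss_and_grad = jax.jit(jax.value_and_grad(logistic_loss))

params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
x = jnp.ones((4, 3))
y = jnp.array([0.0, 1.0, 1.0, 0.0])
loss, grads = loss_and_grad(params, x, y)
```

With all-zero parameters every logit is 0, so the loss equals log 2; because the labels are balanced and every feature is 1, the gradients for both `"w"` and `"b"` are zero.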