Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scenic/model_lib/base_models/classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
})


def classification_metrics_function(
def classification_metrics_function( # pytype: disable=annotation-type-mismatch
logits: jnp.ndarray,
batch: base_model.Batch,
target_is_onehot: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion scenic/model_lib/base_models/encoder_decoder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def num_tokens(logits: jnp.ndarray,
_MAX_PERPLEXITY = 1.0e4


def encoder_decoder_metrics_function(
def encoder_decoder_metrics_function( # pytype: disable=annotation-type-mismatch
logits: jnp.ndarray,
batch: base_model.Batch,
target_is_onehot: bool = False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def multilabel_classification_metrics_function(
logits: jnp.ndarray,
batch: base_model.Batch,
target_is_multihot: bool = False,
metrics: base_model.MetricNormalizerFnDict = _MULTI_LABEL_CLASSIFICATION_METRICS,
metrics: base_model.MetricNormalizerFnDict = _MULTI_LABEL_CLASSIFICATION_METRICS, # pytype: disable=annotation-type-mismatch
axis_name: Union[str, Tuple[str, ...]] = 'batch',
) -> Dict[str, Tuple[float, int]]:
"""Calculates metrics for the multi-label classification task.
Expand Down
2 changes: 1 addition & 1 deletion scenic/model_lib/base_models/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
})


def regression_metrics_function(
def regression_metrics_function( # pytype: disable=annotation-type-mismatch
predictions: jnp.ndarray,
batch: base_model.Batch,
metrics: base_model.MetricNormalizerFnDict = _REGRESSION_METRICS,
Expand Down
2 changes: 1 addition & 1 deletion scenic/model_lib/base_models/segmentation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def num_pixels(logits: jnp.ndarray,
})


def semantic_segmentation_metrics_function(
def semantic_segmentation_metrics_function( # pytype: disable=annotation-type-mismatch
logits: jnp.ndarray,
batch: base_model.Batch,
target_is_onehot: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion scenic/projects/av_mae/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def get_spectogram_targets(inputs: jnp.ndarray,
return patched_input


def feature_regression_metrics_function(
def feature_regression_metrics_function( # pytype: disable=annotation-type-mismatch
predictions: jnp.ndarray,
prediction_masks: jnp.ndarray,
batch: base_model.Batch,
Expand Down
6 changes: 3 additions & 3 deletions scenic/projects/polyvit/polyvit_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class Modality:
})


def bow_classification_metrics_function(
def bow_classification_metrics_function( # pytype: disable=annotation-type-mismatch
logits: jnp.ndarray,
batch: base_model.Batch,
target_is_multihot: bool = False,
Expand Down Expand Up @@ -120,7 +120,7 @@ def bow_classification_metrics_function(
return evaluated_metrics # pytype: disable=bad-return-type # jax-types


def multihead_classification_metrics_function(
def multihead_classification_metrics_function( # pytype: disable=annotation-type-mismatch
logits,
batch,
metrics: base_model.MetricNormalizerFnDict = _MULTIHEADLABEL_METRICS,
Expand Down Expand Up @@ -183,7 +183,7 @@ def multihead_classification_metrics_function(

def classification_metrics_function_with_acc_top_5(*args, **kwargs):
"""A wrapper over classification_metrics_function which has accuracy_top_5."""
return classification_metrics_function(
return classification_metrics_function( # pytype: disable=wrong-arg-types
*args, metrics=_MULTIHEADLABEL_METRICS, **kwargs)


Expand Down