From 38b001e514eb550ebd4e953c3efdc3b02ad3fdf4 Mon Sep 17 00:00:00 2001 From: Amer Elsheikh Date: Wed, 18 Feb 2026 14:54:49 -0800 Subject: [PATCH] Ignore some pytype errors. PiperOrigin-RevId: 872055299 --- scenic/model_lib/base_models/classification_model.py | 2 +- scenic/model_lib/base_models/encoder_decoder_model.py | 2 +- .../base_models/multilabel_classification_model.py | 2 +- scenic/model_lib/base_models/regression_model.py | 2 +- scenic/model_lib/base_models/segmentation_model.py | 2 +- scenic/projects/av_mae/base_model.py | 2 +- scenic/projects/polyvit/polyvit_base_model.py | 6 +++--- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/scenic/model_lib/base_models/classification_model.py b/scenic/model_lib/base_models/classification_model.py index 7829909a..00e82f65 100644 --- a/scenic/model_lib/base_models/classification_model.py +++ b/scenic/model_lib/base_models/classification_model.py @@ -33,7 +33,7 @@ }) -def classification_metrics_function( +def classification_metrics_function( # pytype: disable=annotation-type-mismatch logits: jnp.ndarray, batch: base_model.Batch, target_is_onehot: bool = False, diff --git a/scenic/model_lib/base_models/encoder_decoder_model.py b/scenic/model_lib/base_models/encoder_decoder_model.py index 15ae6858..57b595fa 100644 --- a/scenic/model_lib/base_models/encoder_decoder_model.py +++ b/scenic/model_lib/base_models/encoder_decoder_model.py @@ -61,7 +61,7 @@ def num_tokens(logits: jnp.ndarray, _MAX_PERPLEXITY = 1.0e4 -def encoder_decoder_metrics_function( +def encoder_decoder_metrics_function( # pytype: disable=annotation-type-mismatch logits: jnp.ndarray, batch: base_model.Batch, target_is_onehot: bool = False, diff --git a/scenic/model_lib/base_models/multilabel_classification_model.py b/scenic/model_lib/base_models/multilabel_classification_model.py index 61c74c1c..fc964d10 100644 --- a/scenic/model_lib/base_models/multilabel_classification_model.py +++ b/scenic/model_lib/base_models/multilabel_classification_model.py @@ -37,7 +37,7 @@ def multilabel_classification_metrics_function( logits: jnp.ndarray, batch: base_model.Batch, target_is_multihot: bool = False, - metrics: base_model.MetricNormalizerFnDict = _MULTI_LABEL_CLASSIFICATION_METRICS, + metrics: base_model.MetricNormalizerFnDict = _MULTI_LABEL_CLASSIFICATION_METRICS, # pytype: disable=annotation-type-mismatch axis_name: Union[str, Tuple[str, ...]] = 'batch', ) -> Dict[str, Tuple[float, int]]: """Calculates metrics for the multi-label classification task. diff --git a/scenic/model_lib/base_models/regression_model.py b/scenic/model_lib/base_models/regression_model.py index 0633851c..f3a1d6be 100644 --- a/scenic/model_lib/base_models/regression_model.py +++ b/scenic/model_lib/base_models/regression_model.py @@ -29,7 +29,7 @@ }) -def regression_metrics_function( +def regression_metrics_function( # pytype: disable=annotation-type-mismatch predictions: jnp.ndarray, batch: base_model.Batch, metrics: base_model.MetricNormalizerFnDict = _REGRESSION_METRICS, diff --git a/scenic/model_lib/base_models/segmentation_model.py b/scenic/model_lib/base_models/segmentation_model.py index feb58a01..fcc67585 100644 --- a/scenic/model_lib/base_models/segmentation_model.py +++ b/scenic/model_lib/base_models/segmentation_model.py @@ -61,7 +61,7 @@ def num_pixels(logits: jnp.ndarray, }) -def semantic_segmentation_metrics_function( +def semantic_segmentation_metrics_function( # pytype: disable=annotation-type-mismatch logits: jnp.ndarray, batch: base_model.Batch, target_is_onehot: bool = False, diff --git a/scenic/projects/av_mae/base_model.py b/scenic/projects/av_mae/base_model.py index 322b6ae4..69f86ce5 100644 --- a/scenic/projects/av_mae/base_model.py +++ b/scenic/projects/av_mae/base_model.py @@ -194,7 +194,7 @@ def get_spectogram_targets(inputs: jnp.ndarray, return patched_input -def feature_regression_metrics_function( +def feature_regression_metrics_function( # pytype: disable=annotation-type-mismatch predictions: jnp.ndarray, prediction_masks: jnp.ndarray, batch: base_model.Batch, diff --git a/scenic/projects/polyvit/polyvit_base_model.py b/scenic/projects/polyvit/polyvit_base_model.py index f45bac80..7a5b0b8b 100644 --- a/scenic/projects/polyvit/polyvit_base_model.py +++ b/scenic/projects/polyvit/polyvit_base_model.py @@ -66,7 +66,7 @@ class Modality: }) -def bow_classification_metrics_function( +def bow_classification_metrics_function( # pytype: disable=annotation-type-mismatch logits: jnp.ndarray, batch: base_model.Batch, target_is_multihot: bool = False, @@ -120,7 +120,7 @@ def bow_classification_metrics_function( return evaluated_metrics # pytype: disable=bad-return-type # jax-types -def multihead_classification_metrics_function( +def multihead_classification_metrics_function( # pytype: disable=annotation-type-mismatch logits, batch, metrics: base_model.MetricNormalizerFnDict = _MULTIHEADLABEL_METRICS, @@ -183,7 +183,7 @@ def multihead_classification_metrics_function( def classification_metrics_function_with_acc_top_5(*args, **kwargs): """A wrapper over classification_metrics_function which has accuracy_top_5.""" - return classification_metrics_function( + return classification_metrics_function( # pytype: disable=wrong-arg-types *args, metrics=_MULTIHEADLABEL_METRICS, **kwargs)