1515from .. import settings
1616from ..distributions import MultitaskMultivariateNormal , MultivariateNormal
1717from ..likelihoods import _GaussianLikelihoodBase
18- from ..utils .generic import length_safe_zip
1918from ..utils .warnings import GPInputWarning
2019from .exact_prediction_strategies import prediction_strategy
2120from .gp import GP
@@ -128,7 +127,7 @@ def set_train_data(
128127 inputs = (inputs ,)
129128 inputs = tuple (input_ .unsqueeze (- 1 ) if input_ .ndimension () == 1 else input_ for input_ in inputs )
130129 if strict :
131- for input_ , t_input in length_safe_zip (inputs , self .train_inputs or (None ,)):
130+ for input_ , t_input in zip (inputs , self .train_inputs or (None ,), strict = True ):
132131 for attr in {"shape" , "dtype" , "device" }:
133132 expected_attr = getattr (t_input , attr , None )
134133 found_attr = getattr (input_ , attr , None )
@@ -221,7 +220,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
221220 [train_input , input .expand (input_batch_shape + input .shape [- 2 :])],
222221 dim = - 2 ,
223222 )
224- for train_input , input in length_safe_zip (train_inputs , inputs )
223+ for train_input , input in zip (train_inputs , inputs , strict = True )
225224 ]
226225 full_targets = torch .cat (
227226 [train_targets , targets .expand (target_batch_shape + targets .shape [data_dim_start :])], dim = data_dim_start
@@ -276,7 +275,7 @@ def __call__(self, *args, **kwargs):
276275 )
277276 if settings .debug .on ():
278277 if not all (
279- torch .equal (train_input , input ) for train_input , input in length_safe_zip (train_inputs , inputs )
278+ torch .equal (train_input , input ) for train_input , input in zip (train_inputs , inputs , strict = True )
280279 ):
281280 raise RuntimeError ("You must train on the training inputs!" )
282281 res = super ().__call__ (* inputs , ** kwargs )
@@ -294,7 +293,9 @@ def __call__(self, *args, **kwargs):
294293 # Posterior mode
295294 else :
296295 if settings .debug .on ():
297- if all (torch .equal (train_input , input ) for train_input , input in length_safe_zip (train_inputs , inputs )):
296+ if all (
297+ torch .equal (train_input , input ) for train_input , input in zip (train_inputs , inputs , strict = True )
298+ ):
298299 warnings .warn (
299300 "The input matches the stored training data. Did you forget to call model.train()?" ,
300301 GPInputWarning ,
@@ -382,7 +383,7 @@ def _get_test_prior_mean_and_covariances(
382383 # Concatenate the input to the training input
383384 full_inputs = []
384385 batch_shape = train_inputs [0 ].shape [:- 2 ]
385- for train_input , input in length_safe_zip (train_inputs , test_inputs ):
386+ for train_input , input in zip (train_inputs , test_inputs , strict = True ):
386387 # Make sure the batch shapes agree for training/test data
387388 if batch_shape != train_input .shape [:- 2 ]:
388389 batch_shape = torch .broadcast_shapes (batch_shape , train_input .shape [:- 2 ])
0 commit comments