1616from .. import settings
1717from ..distributions import MultitaskMultivariateNormal , MultivariateNormal
1818from ..likelihoods import _GaussianLikelihoodBase
19- from ..utils .generic import length_safe_zip
2019from ..utils .warnings import GPInputWarning
2120from .exact_prediction_strategies import prediction_strategy
2221from .gp import GP
@@ -129,7 +128,7 @@ def set_train_data(
129128 inputs = (inputs ,)
130129 inputs = tuple (input_ .unsqueeze (- 1 ) if input_ .ndimension () == 1 else input_ for input_ in inputs )
131130 if strict :
132- for input_ , t_input in length_safe_zip (inputs , self .train_inputs or (None ,)):
131+ for input_ , t_input in zip (inputs , self .train_inputs or (None ,), strict = True ):
133132 for attr in {"shape" , "dtype" , "device" }:
134133 expected_attr = getattr (t_input , attr , None )
135134 found_attr = getattr (input_ , attr , None )
@@ -222,7 +221,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
222221 [train_input , input .expand (input_batch_shape + input .shape [- 2 :])],
223222 dim = - 2 ,
224223 )
225- for train_input , input in length_safe_zip (train_inputs , inputs )
224+ for train_input , input in zip (train_inputs , inputs , strict = True )
226225 ]
227226 full_targets = torch .cat (
228227 [train_targets , targets .expand (target_batch_shape + targets .shape [data_dim_start :])], dim = data_dim_start
@@ -277,7 +276,7 @@ def __call__(self, *args, **kwargs):
277276 )
278277 if settings .debug .on ():
279278 if not all (
280- torch .equal (train_input , input ) for train_input , input in length_safe_zip (train_inputs , inputs )
279+ torch .equal (train_input , input ) for train_input , input in zip (train_inputs , inputs , strict = True )
281280 ):
282281 raise RuntimeError ("You must train on the training inputs!" )
283282 res = super ().__call__ (* inputs , ** kwargs )
@@ -295,7 +294,7 @@ def __call__(self, *args, **kwargs):
295294 # Posterior mode
296295 else :
297296 if settings .debug .on ():
298- if all (torch .equal (train_input , input ) for train_input , input in length_safe_zip (train_inputs , inputs )):
297+ if all (torch .equal (train_input , input ) for train_input , input in zip (train_inputs , inputs , strict = True )):
299298 warnings .warn (
300299 "The input matches the stored training data. Did you forget to call model.train()?" ,
301300 GPInputWarning ,
@@ -381,7 +380,7 @@ def _get_test_prior_mean_and_covariances(
381380 # Concatenate the input to the training input
382381 full_inputs = []
383382 batch_shape = train_inputs [0 ].shape [:- 2 ]
384- for train_input , input in length_safe_zip (train_inputs , test_inputs ):
383+ for train_input , input in zip (train_inputs , test_inputs , strict = True ):
385384 # Make sure the batch shapes agree for training/test data
386385 if batch_shape != train_input .shape [:- 2 ]:
387386 batch_shape = torch .broadcast_shapes (batch_shape , train_input .shape [:- 2 ])
0 commit comments