Skip to content

Commit 9446cdf

Browse files
Replace length_safe_zip with zip(..., strict=True)
1 parent c4ee5f5 commit 9446cdf

6 files changed

Lines changed: 21 additions & 37 deletions

File tree

gpytorch/kernels/lcm_kernel.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from torch.nn import ModuleList
88

99
from ..priors import Prior
10-
from ..utils.generic import length_safe_zip
1110
from .kernel import Kernel
1211
from .multitask_kernel import MultitaskKernel
1312

@@ -50,7 +49,7 @@ def __init__(
5049
self.covar_module_list = ModuleList(
5150
[
5251
MultitaskKernel(base_kernel, num_tasks=num_tasks, rank=r, task_covar_prior=task_covar_prior)
53-
for base_kernel, r in length_safe_zip(base_kernels, rank)
52+
for base_kernel, r in zip(base_kernels, rank, strict=True)
5453
]
5554
)
5655

gpytorch/likelihoods/likelihood_list.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from torch.nn import ModuleList
44

55
from gpytorch.likelihoods import Likelihood
6-
from gpytorch.utils.generic import length_safe_zip
76

87

98
def _get_tuple_args_(*args):
@@ -22,7 +21,7 @@ def __init__(self, *likelihoods):
2221
def expected_log_prob(self, *args, **kwargs):
2322
return [
2423
likelihood.expected_log_prob(*args_, **kwargs)
25-
for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
24+
for likelihood, args_ in zip(self.likelihoods, _get_tuple_args_(*args), strict=True)
2625
]
2726

2827
def forward(self, *args, **kwargs):
@@ -31,18 +30,18 @@ def forward(self, *args, **kwargs):
3130
# if noise kwarg is passed, assume it's an iterable of noise tensors
3231
return [
3332
likelihood.forward(*args_, {**kwargs, "noise": noise_})
34-
for likelihood, args_, noise_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args), noise)
33+
for likelihood, args_, noise_ in zip(self.likelihoods, _get_tuple_args_(*args), noise, strict=True)
3534
]
3635
else:
3736
return [
3837
likelihood.forward(*args_, **kwargs)
39-
for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
38+
for likelihood, args_ in zip(self.likelihoods, _get_tuple_args_(*args), strict=True)
4039
]
4140

4241
def pyro_sample_output(self, *args, **kwargs):
4342
return [
4443
likelihood.pyro_sample_output(*args_, **kwargs)
45-
for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
44+
for likelihood, args_ in zip(self.likelihoods, _get_tuple_args_(*args), strict=True)
4645
]
4746

4847
def __call__(self, *args, **kwargs):
@@ -51,10 +50,10 @@ def __call__(self, *args, **kwargs):
5150
# if noise kwarg is passed, assume it's an iterable of noise tensors
5251
return [
5352
likelihood(*args_, {**kwargs, "noise": noise_})
54-
for likelihood, args_, noise_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args), noise)
53+
for likelihood, args_, noise_ in zip(self.likelihoods, _get_tuple_args_(*args), noise, strict=True)
5554
]
5655
else:
5756
return [
5857
likelihood(*args_, **kwargs)
59-
for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
58+
for likelihood, args_ in zip(self.likelihoods, _get_tuple_args_(*args), strict=True)
6059
]

gpytorch/mlls/sum_marginal_log_likelihood.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from torch.nn import ModuleList
44

55
from gpytorch.mlls import ExactMarginalLogLikelihood, MarginalLogLikelihood
6-
from gpytorch.utils.generic import length_safe_zip
76

87

98
class SumMarginalLogLikelihood(MarginalLogLikelihood):
@@ -31,10 +30,10 @@ def forward(self, outputs, targets, *params):
3130
(e.g. parameters in case of heteroskedastic likelihoods)
3231
"""
3332
if len(params) == 0:
34-
sum_mll = sum(mll(output, target) for mll, output, target in length_safe_zip(self.mlls, outputs, targets))
33+
sum_mll = sum(mll(output, target) for mll, output, target in zip(self.mlls, outputs, targets, strict=True))
3534
else:
3635
sum_mll = sum(
3736
mll(output, target, *iparams)
38-
for mll, output, target, iparams in length_safe_zip(self.mlls, outputs, targets, params)
37+
for mll, output, target, iparams in zip(self.mlls, outputs, targets, params, strict=True)
3938
)
4039
return sum_mll.div_(len(self.mlls))

gpytorch/models/exact_gp.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from .. import settings
1717
from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
1818
from ..likelihoods import _GaussianLikelihoodBase
19-
from ..utils.generic import length_safe_zip
2019
from ..utils.warnings import GPInputWarning
2120
from .exact_prediction_strategies import prediction_strategy
2221
from .gp import GP
@@ -129,7 +128,7 @@ def set_train_data(
129128
inputs = (inputs,)
130129
inputs = tuple(input_.unsqueeze(-1) if input_.ndimension() == 1 else input_ for input_ in inputs)
131130
if strict:
132-
for input_, t_input in length_safe_zip(inputs, self.train_inputs or (None,)):
131+
for input_, t_input in zip(inputs, self.train_inputs or (None,), strict=True):
133132
for attr in {"shape", "dtype", "device"}:
134133
expected_attr = getattr(t_input, attr, None)
135134
found_attr = getattr(input_, attr, None)
@@ -222,7 +221,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
222221
[train_input, input.expand(input_batch_shape + input.shape[-2:])],
223222
dim=-2,
224223
)
225-
for train_input, input in length_safe_zip(train_inputs, inputs)
224+
for train_input, input in zip(train_inputs, inputs, strict=True)
226225
]
227226
full_targets = torch.cat(
228227
[train_targets, targets.expand(target_batch_shape + targets.shape[data_dim_start:])], dim=data_dim_start
@@ -277,7 +276,7 @@ def __call__(self, *args, **kwargs):
277276
)
278277
if settings.debug.on():
279278
if not all(
280-
torch.equal(train_input, input) for train_input, input in length_safe_zip(train_inputs, inputs)
279+
torch.equal(train_input, input) for train_input, input in zip(train_inputs, inputs, strict=True)
281280
):
282281
raise RuntimeError("You must train on the training inputs!")
283282
res = super().__call__(*inputs, **kwargs)
@@ -295,7 +294,7 @@ def __call__(self, *args, **kwargs):
295294
# Posterior mode
296295
else:
297296
if settings.debug.on():
298-
if all(torch.equal(train_input, input) for train_input, input in length_safe_zip(train_inputs, inputs)):
297+
if all(torch.equal(train_input, input) for train_input, input in zip(train_inputs, inputs, strict=True)):
299298
warnings.warn(
300299
"The input matches the stored training data. Did you forget to call model.train()?",
301300
GPInputWarning,
@@ -381,7 +380,7 @@ def _get_test_prior_mean_and_covariances(
381380
# Concatenate the input to the training input
382381
full_inputs = []
383382
batch_shape = train_inputs[0].shape[:-2]
384-
for train_input, input in length_safe_zip(train_inputs, test_inputs):
383+
for train_input, input in zip(train_inputs, test_inputs, strict=True):
385384
# Make sure the batch shapes agree for training/test data
386385
if batch_shape != train_input.shape[:-2]:
387386
batch_shape = torch.broadcast_shapes(batch_shape, train_input.shape[:-2])

gpytorch/models/model_list.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from gpytorch.likelihoods import LikelihoodList
99
from gpytorch.models import GP
10-
from gpytorch.utils.generic import length_safe_zip
1110

1211

1312
class AbstractModelList(GP, ABC):
@@ -39,7 +38,7 @@ def likelihood_i(self, i, *args, **kwargs):
3938

4039
def forward(self, *args, **kwargs):
4140
return [
42-
model.forward(*args_, **kwargs) for model, args_ in length_safe_zip(self.models, _get_tensor_args(*args))
41+
model.forward(*args_, **kwargs) for model, args_ in zip(self.models, _get_tensor_args(*args), strict=True)
4342
]
4443

4544
def get_fantasy_model(self, inputs, targets, **kwargs):
@@ -66,18 +65,19 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
6665

6766
fantasy_models = [
6867
model.get_fantasy_model(*inputs_, *targets_, **kwargs_)
69-
for model, inputs_, targets_, kwargs_ in length_safe_zip(
68+
for model, inputs_, targets_, kwargs_ in zip(
7069
self.models,
7170
_get_tensor_args(*inputs),
7271
_get_tensor_args(*targets),
7372
kwargs,
73+
strict=True,
7474
)
7575
]
7676
return self.__class__(*fantasy_models)
7777

7878
def __call__(self, *args, **kwargs):
7979
return [
80-
model.__call__(*args_, **kwargs) for model, args_ in length_safe_zip(self.models, _get_tensor_args(*args))
80+
model.__call__(*args_, **kwargs) for model, args_ in zip(self.models, _get_tensor_args(*args), strict=True)
8181
]
8282

8383
@property

gpytorch/utils/generic.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,6 @@
22

33
from __future__ import annotations
44

5-
6-
def length_safe_zip(*args):
7-
"""Python's `zip(...)` with checks to ensure the arguments have
8-
the same number of elements.
9-
10-
NOTE: This converts all args that do not define "__len__" to a list.
11-
"""
12-
args = [a if hasattr(a, "__len__") else list(a) for a in args]
13-
if len({len(a) for a in args}) > 1:
14-
raise ValueError(
15-
"Expected the lengths of all arguments to be equal. Got lengths "
16-
f"{[len(a) for a in args]} for args {args}. Did you pass in "
17-
"fewer inputs than expected?"
18-
)
19-
return zip(*args)
5+
# This module previously contained the `length_safe_zip` function.
6+
# That function has been removed in favor of using Python's built-in
7+
# `zip(..., strict=True)` which provides the same length-checking behavior.

0 commit comments

Comments
 (0)