Skip to content

Commit 59228ad

Browse files
generatedunixname89002005307016facebook-github-bot
authored andcommitted
Add type error suppressions for upcoming upgrade
Reviewed By: MaggieMoss Differential Revision: D77242941 fbshipit-source-id: 08b3645f9b4189ebc2da2ed80c9f71abe4d5926d
1 parent 127141a commit 59228ad

File tree

4 files changed

+8
-1
lines changed

4 files changed

+8
-1
lines changed

flowtorch/bijectors/permute.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def _inverse(
4646
if self.permutation is None:
4747
self.permutation = torch.randperm(y.shape[-1])
4848

49+
# pyre-fixme[6]: For 3rd argument expected `Tensor` but got `Optional[Tensor]`.
4950
x = torch.index_select(y, -1, self.inv_permutation)
5051
ladj = self._log_abs_det_jacobian(x, y, params)
5152
return x, ladj

flowtorch/distributions/flow.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(
3030
# TODO: Confirm that the following logic works. Shouldn't it use
3131
# .domain and .codomain?? Infer shape from constructed self.bijector
3232
shape = (
33+
# pyre-fixme[6]: For 1st argument expected `tuple[int, ...]` but got `Size`.
3334
self.base_dist.batch_shape
3435
# pyre-fixme[58]: `+` is not supported for operand types `Size` and `Size`.
3536
+ self.base_dist.event_shape

flowtorch/distributions/neals_funnel.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@ def rsample(
3434
if not sample_shape:
3535
sample_shape = torch.Size()
3636
eps = _standard_normal(
37-
(sample_shape[0], 2), dtype=torch.float, device=torch.device("cpu")
37+
# pyre-fixme[6]: For 1st argument expected `Sequence[Union[int,
38+
# SymInt]]` but got `Tuple[Union[int, Tensor], int]`.
39+
(sample_shape[0], 2),
40+
dtype=torch.float,
41+
device=torch.device("cpu"),
3842
)
3943
z = torch.zeros(eps.shape)
4044
z[..., 1] = torch.tensor(3.0) * eps[..., 1]

flowtorch/parameters/dense_autoregressive.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def _forward(
182182
# results ~ (batch_shape, param_shapes[0]), ...
183183
result = tuple(
184184
# pyre-fixme[58]: `+` is not supported for operand types `Size` and `Size`.
185+
# pyre-fixme[6]: For 1st argument expected `tuple[int, ...]` but got `Size`.
185186
h_slice.view(batch_shape + p_shape)
186187
for h_slice, p_shape in zip(result, list(self.param_shapes))
187188
)

0 commit comments

Comments
 (0)