Skip to content

Commit a3c3440

Browse files
committed
sampling_kernels: add missing device arg
1 parent f93d017 commit a3c3440

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/torchlensmaker/sampling/sampling_kernels.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def example_inputs(
170170
def example_params(
171171
dtype: torch.dtype, device: torch.device
172172
) -> tuple[torch.Tensor, ...]:
173-
return (torch.tensor(10, dtype=torch.int64),)
173+
return (torch.tensor(10, dtype=torch.int64, device=device),)
174174

175175

176176
class LinspaceSampling2DKernel(FunctionalKernel):
@@ -206,8 +206,8 @@ def example_params(
206206
dtype: torch.dtype, device: torch.device
207207
) -> tuple[torch.Tensor, ...]:
208208
return (
209-
torch.tensor(10, dtype=torch.int64),
210-
torch.tensor(11, dtype=torch.int64),
209+
torch.tensor(10, dtype=torch.int64, device=device),
210+
torch.tensor(11, dtype=torch.int64, device=device),
211211
)
212212

213213

@@ -240,6 +240,6 @@ def example_params(
240240
dtype: torch.dtype, device: torch.device
241241
) -> tuple[torch.Tensor, ...]:
242242
return (
243-
torch.tensor(10, dtype=torch.int64),
244-
torch.tensor(11, dtype=torch.int64),
243+
torch.tensor(10, dtype=torch.int64, device=device),
244+
torch.tensor(11, dtype=torch.int64, device=device),
245245
)

0 commit comments

Comments
 (0)