@@ -170,7 +170,7 @@ def example_inputs(
170170 def example_params (
171171 dtype : torch .dtype , device : torch .device
172172 ) -> tuple [torch .Tensor , ...]:
173- return (torch .tensor (10 , dtype = torch .int64 ),)
173+ return (torch .tensor (10 , dtype = torch .int64 , device = device ),)
174174
175175
176176class LinspaceSampling2DKernel (FunctionalKernel ):
@@ -206,8 +206,8 @@ def example_params(
206206 dtype : torch .dtype , device : torch .device
207207 ) -> tuple [torch .Tensor , ...]:
208208 return (
209- torch .tensor (10 , dtype = torch .int64 ),
210- torch .tensor (11 , dtype = torch .int64 ),
209+ torch .tensor (10 , dtype = torch .int64 , device = device ),
210+ torch .tensor (11 , dtype = torch .int64 , device = device ),
211211 )
212212
213213
@@ -240,6 +240,6 @@ def example_params(
240240 dtype : torch .dtype , device : torch .device
241241 ) -> tuple [torch .Tensor , ...]:
242242 return (
243- torch .tensor (10 , dtype = torch .int64 ),
244- torch .tensor (11 , dtype = torch .int64 ),
243+ torch .tensor (10 , dtype = torch .int64 , device = device ),
244+ torch .tensor (11 , dtype = torch .int64 , device = device ),
245245 )
0 commit comments