Skip to content

Commit 7beaafd

Browse files
committed
Add non-contiguous test.
1 parent 979589e commit 7beaafd

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

tester/api_config/config_analyzer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
import numpy
1010
import paddle
1111
import torch
12+
import random
1213

1314
USE_CACHED_NUMPY = os.getenv("USE_CACHED_NUMPY", "False").lower() == "true"
15+
TEST_NON_CONTIGUOUS = os.getenv("TEST_NON_CONTIGUOUS", "False").lower() == "true"
1416
cached_numpy = {}
1517

1618
not_zero_apis = frozenset(
@@ -84,6 +86,7 @@ def __init__(self, shape, dtype, place=None):
8486
self.numpy_tensor = None
8587
self.paddle_tensor = None
8688
self.torch_tensor = None
89+
self.shuffle_dims = None
8790

8891
def __deepcopy__(self, memo):
8992
cls = self.__class__
@@ -2642,6 +2645,13 @@ def get_paddle_tensor(self, api_config):
26422645
self.paddle_tensor.stop_gradient = False
26432646
if self.dtype == "bfloat16":
26442647
self.paddle_tensor = paddle.cast(self.paddle_tensor, dtype="bfloat16")
2648+
if TEST_NON_CONTIGUOUS:
2649+
if not self.shuffle_dims:
2650+
ndim = self.paddle_tensor.dim()
2651+
self.shuffle_dims = list(range(ndim))
2652+
random.shuffle(self.shuffle_dims)
2653+
print("paddle shuffle:", self.shuffle_dims)
2654+
return paddle.transpose(self.paddle_tensor, self.shuffle_dims)
26452655
return self.paddle_tensor
26462656

26472657
def get_torch_tensor(self, api_config):
@@ -2669,6 +2679,13 @@ def get_torch_tensor(self, api_config):
26692679
)
26702680
if self.dtype == "bfloat16":
26712681
self.torch_tensor = self.torch_tensor.to(dtype=torch.bfloat16)
2682+
if TEST_NON_CONTIGUOUS:
2683+
if not self.shuffle_dims:
2684+
ndim = self.torch_tensor.dim()
2685+
self.shuffle_dims = list(range(ndim))
2686+
random.shuffle(self.shuffle_dims)
2687+
print("torch shuffle:", self.shuffle_dims)
2688+
return torch.permute(self.torch_tensor, self.shuffle_dims)
26722689
return self.torch_tensor
26732690

26742691
def clear_tensor(self):

0 commit comments

Comments
 (0)