|
9 | 9 | import numpy |
10 | 10 | import paddle |
11 | 11 | import torch |
| 12 | +import random |
12 | 13 |
|
13 | 14 | USE_CACHED_NUMPY = os.getenv("USE_CACHED_NUMPY", "False").lower() == "true" |
| 15 | +TEST_NON_CONTIGUOUS = os.getenv("TEST_NON_CONTIGUOUS", "False").lower() == "true" |
14 | 16 | cached_numpy = {} |
15 | 17 |
|
16 | 18 | not_zero_apis = frozenset( |
@@ -84,6 +86,7 @@ def __init__(self, shape, dtype, place=None): |
84 | 86 | self.numpy_tensor = None |
85 | 87 | self.paddle_tensor = None |
86 | 88 | self.torch_tensor = None |
| 89 | + self.shuffle_dims = None |
87 | 90 |
|
88 | 91 | def __deepcopy__(self, memo): |
89 | 92 | cls = self.__class__ |
@@ -2642,6 +2645,13 @@ def get_paddle_tensor(self, api_config): |
2642 | 2645 | self.paddle_tensor.stop_gradient = False |
2643 | 2646 | if self.dtype == "bfloat16": |
2644 | 2647 | self.paddle_tensor = paddle.cast(self.paddle_tensor, dtype="bfloat16") |
| 2648 | + if TEST_NON_CONTIGUOUS: |
| 2649 | + if not self.shuffle_dims: |
| 2650 | + ndim = self.paddle_tensor.dim() |
| 2651 | + self.shuffle_dims = list(range(ndim)) |
| 2652 | + random.shuffle(self.shuffle_dims) |
| 2653 | + print("paddle shuffle:", self.shuffle_dims) |
| 2654 | + return paddle.transpose(self.paddle_tensor, self.shuffle_dims) |
2645 | 2655 | return self.paddle_tensor |
2646 | 2656 |
|
2647 | 2657 | def get_torch_tensor(self, api_config): |
@@ -2669,6 +2679,13 @@ def get_torch_tensor(self, api_config): |
2669 | 2679 | ) |
2670 | 2680 | if self.dtype == "bfloat16": |
2671 | 2681 | self.torch_tensor = self.torch_tensor.to(dtype=torch.bfloat16) |
| 2682 | + if TEST_NON_CONTIGUOUS: |
| 2683 | + if not self.shuffle_dims: |
| 2684 | + ndim = self.torch_tensor.dim() |
| 2685 | + self.shuffle_dims = list(range(ndim)) |
| 2686 | + random.shuffle(self.shuffle_dims) |
| 2687 | + print("torch shuffle:", self.shuffle_dims) |
| 2688 | + return torch.permute(self.torch_tensor, self.shuffle_dims) |
2672 | 2689 | return self.torch_tensor |
2673 | 2690 |
|
2674 | 2691 | def clear_tensor(self): |
|
0 commit comments