Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions tester/api_config/config_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import copy
import math
import os
import random
import re

import numpy
import paddle
import torch

USE_CACHED_NUMPY = os.getenv("USE_CACHED_NUMPY", "False").lower() == "true"
TEST_NON_CONTIGUOUS = os.getenv("TEST_NON_CONTIGUOUS", "0").lower() in ("true", "1")
cached_numpy = {}

not_zero_apis = frozenset(
Expand Down Expand Up @@ -84,6 +86,7 @@ def __init__(self, shape, dtype, place=None):
self.numpy_tensor = None
self.paddle_tensor = None
self.torch_tensor = None
self.shuffle_dims = None

def __deepcopy__(self, memo):
cls = self.__class__
Expand Down Expand Up @@ -2642,6 +2645,13 @@ def get_paddle_tensor(self, api_config):
self.paddle_tensor.stop_gradient = False
if self.dtype == "bfloat16":
self.paddle_tensor = paddle.cast(self.paddle_tensor, dtype="bfloat16")
if TEST_NON_CONTIGUOUS:
if not self.shuffle_dims:
ndim = self.paddle_tensor.dim()
self.shuffle_dims = list(range(ndim))
random.shuffle(self.shuffle_dims)
print("paddle shuffle:", self.shuffle_dims)
return paddle.transpose(self.paddle_tensor, self.shuffle_dims)
return self.paddle_tensor

def get_torch_tensor(self, api_config):
Expand Down Expand Up @@ -2669,6 +2679,13 @@ def get_torch_tensor(self, api_config):
)
if self.dtype == "bfloat16":
self.torch_tensor = self.torch_tensor.to(dtype=torch.bfloat16)
if TEST_NON_CONTIGUOUS:
if not self.shuffle_dims:
ndim = self.torch_tensor.dim()
self.shuffle_dims = list(range(ndim))
random.shuffle(self.shuffle_dims)
print("torch shuffle:", self.shuffle_dims)
return torch.permute(self.torch_tensor, self.shuffle_dims)
return self.torch_tensor

def clear_tensor(self):
Expand Down