ForeachUtils.h
#pragma once

#include <ATen/ATen.h>

namespace at {
namespace native {
namespace {

// Set of foreach API restrictions
// - All tensors must be of the same dtype
// - All corresponding tensors must be of the same size
void check_foreach_api_restrictions(TensorList tensors) {
  TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");

  auto expected_dtype = tensors[0].dtype();
  for (const auto& t : tensors) {
    TORCH_CHECK(t.dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
  }
}

void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2) {
  TORCH_CHECK(tensors1.size() > 0, "Tensor list must have at least one tensor.");
  TORCH_CHECK(tensors2.size() > 0, "Tensor list must have at least one tensor.");
  TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors2.size());

  auto expected_dtype = tensors1[0].dtype();
  for (size_t i = 0; i < tensors1.size(); i++) {
    TORCH_CHECK(tensors1[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
    TORCH_CHECK(tensors2[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
    TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors2[i].sizes());
  }
}
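
// Illustrative usage sketch, not part of the upstream header: a unary foreach-style op
// is expected to run the matching check above before touching any tensor data. The op
// name `foreach_example_zero_` is an assumption made for this example.
void foreach_example_zero_(TensorList tensors) {
  check_foreach_api_restrictions(tensors);

  // Per-tensor fallback: apply the ordinary op to each tensor in turn.
  for (const auto& t : tensors) {
    t.zero_();
  }
}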
// To go via the 'fast' path, several conditions must be satisfied
// - All tensors must have a strided layout
// - All tensors must be non-overlapping and dense
// - All tensors must be on the same device
// - The resulting tensor must have the same dtype as the input one
bool can_use_fast_route(TensorList tensors, Scalar scalar) {
  TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
  auto expected_device = tensors[0].device();

  for (const auto& t : tensors) {
    if (t.device() != expected_device) {
      return false;
    }

    if (t.layout() != at::kStrided) {
      return false;
    }

    if (!t.is_non_overlapping_and_dense()) {
      return false;
    }

    // complex scalar + integral or boolean tensor will result in a complex tensor
    if (scalar.isComplex() && at::isIntegralType(t.scalar_type(), /*includeBool*/ true)) {
      return false;
    }

    // float scalar + integral or boolean tensor will result in a float tensor
    if (scalar.isFloatingPoint() && at::isIntegralType(t.scalar_type(), /*includeBool*/ true)) {
      return false;
    }

    // integral scalar + boolean tensor will result in an integral tensor
    if (scalar.isIntegral(/*includeBool*/ false) && t.dtype() == at::kBool) {
      return false;
    }
  }

  return true;
}
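
// Illustrative note, not taken from the upstream file: the three type-promotion checks
// above exist because the fast path produces results in the input dtype, while eager-mode
// type promotion may yield a different output dtype. For example, with the default float
// dtype set to kFloat:
//
//   auto t = at::ones({2}, at::kLong);
//   auto out = t.add(0.5);  // promoted: out.dtype() == at::kFloat, not kLong,
//                           // so the fused kernel, which would keep kLong, must bail out.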
bool can_use_fast_route(TensorList tensors1, TensorList tensors2) {
  // Callers are expected to have already validated the lists (e.g. via
  // check_foreach_api_restrictions), so both are non-empty and of equal length.
  auto expected_device = tensors1[0].device();

  for (size_t i = 0; i < tensors1.size(); i++) {
    TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(), "Corresponding tensors from tensor lists have different size.");

    if (tensors1[i].device() != expected_device ||
        tensors2[i].device() != expected_device) {
      return false;
    }

    if (tensors1[i].layout() != at::kStrided ||
        tensors2[i].layout() != at::kStrided) {
      return false;
    }

    if (tensors1[i].strides() != tensors2[i].strides()) {
      return false;
    }

    if (!tensors1[i].is_non_overlapping_and_dense() ||
        !tensors2[i].is_non_overlapping_and_dense()) {
      return false;
    }
  }

  return true;
}
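
// Illustrative dispatch sketch, not part of the upstream header: a foreach op built on
// these helpers validates its inputs, then chooses between a fused multi-tensor kernel
// and a plain per-tensor loop. The op name `foreach_example_add_scalar_` is an
// assumption; this sketch always runs the per-tensor loop and only shows the dispatch.
void foreach_example_add_scalar_(TensorList tensors, Scalar scalar) {
  check_foreach_api_restrictions(tensors);

  const bool fast = can_use_fast_route(tensors, scalar);
  // In a real implementation, `fast == true` would select a single fused kernel that
  // updates every tensor in one launch; `fast == false` falls back to the loop below.
  (void)fast;

  for (const auto& t : tensors) {
    t.add_(scalar);
  }
}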
}
}} // at::native