forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathivalue.cpp
116 lines (101 loc) · 3.05 KB
/
ivalue.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/Formatting.h>
#include <cmath>
namespace c10 {
namespace ivalue {
// Factory for ConstantString: moves `str_` into a new instance built with
// c10::make_intrusive and hands back the resulting intrusive_ptr.
CAFFE2_API c10::intrusive_ptr<ConstantString> ConstantString::create(
    std::string str_) {
  auto created = c10::make_intrusive<ConstantString>(std::move(str_));
  return created;
}
} // namespace ivalue
namespace {
// Writes a delimited, comma-separated rendering of a list-like value.
//
// out    - stream that receives the text; it is also the return value.
// v      - pointer-like handle; `v->elements()` must yield an indexable
//          sequence exposing size() and operator[].
// start  - text emitted before the first element (e.g. "[" or "(").
// finish - text emitted after the last element (e.g. "]" or ")").
//
// The delimiters are taken by const reference so that no std::string is
// copied per call; string-literal arguments still bind through a temporary.
template<typename List>
std::ostream& printList(std::ostream & out, const List &v,
    const std::string& start, const std::string& finish) {
  out << start;
  // Fetch the element sequence once instead of on every loop iteration.
  // auto&& keeps the exact type/constness that v->elements() returns and
  // extends the lifetime if it happens to be returned by value.
  auto&& elems = v->elements();
  const size_t count = elems.size();
  for (size_t i = 0; i < count; ++i) {
    if (i > 0) {
      out << ", ";
    }
    // make sure we use ivalue printing, and not default printing for the element type
    out << IValue(elems[i]);
  }
  out << finish;
  return out;
}
// Writes a dict-like value as "{key: value, key: value}".
//
// out - stream that receives the text; it is also the return value.
// v   - pointer-like handle; `v->elements()` must be iterable and yield
//       entries with streamable `.first` and `.second` members.
//
// Entries are printed in whatever order iteration over `v->elements()`
// produces them; an empty container prints as "{}".
template<typename Dict>
std::ostream& printDict(std::ostream& out, const Dict& v) {
  out << "{";
  auto&& entries = v->elements();
  auto cursor = entries.begin();
  auto stop = entries.end();
  if (cursor != stop) {
    // The first entry is written without a leading separator.
    {
      const auto& head = *cursor;
      out << head.first << ": " << head.second;
    }
    for (++cursor; cursor != stop; ++cursor) {
      const auto& entry = *cursor;
      out << ", ";
      out << entry.first << ": " << entry.second;
    }
  }
  out << "}";
  return out;
}
} // anonymous namespace
// Streams a textual rendering of `v` to `out`, selected by `v.tag`, and
// returns `out`.
//
//  - None / Tensor / Int / Device / Blob: forwarded to the stream operator of
//    the value obtained from the matching accessor.
//  - Double: integral values print as "<int>." (e.g. "3."); every other value
//    is printed with max_digits10 precision, after which the stream's previous
//    precision is restored.
//  - Bool: "True" or "False".
//  - String: the raw string contents.
//  - Tuple: "(a, b)"; the list kinds: "[a, b]"; GenericDict: "{k: v}".
//  - Future: the literal "Future"; Object: "Object<name>".
//
// If no case matches, control reaches AT_ERROR("Tag not found\n")
// (presumably raising an error -- the macro is defined outside this file).
std::ostream& operator<<(std::ostream & out, const IValue & v) {
  switch(v.tag) {
    case IValue::Tag::None:
      return out << v.toNone();
    case IValue::Tag::Tensor:
      return out << v.toTensor();
    case IValue::Tag::Double: {
      const double d = v.toDouble();
      const int c = std::fpclassify(d);
      // Converting a double that lies outside int64_t's range is undefined
      // behaviour, so the integral shortcut is only attempted for values in
      // [-2^63, 2^63). Both bounds are exactly representable as doubles.
      const double int64_bound = 9223372036854775808.0;  // 2^63
      const bool fits_int64 = d >= -int64_bound && d < int64_bound;
      if ((c == FP_NORMAL || c == FP_ZERO) && fits_int64) {
        const int64_t i = static_cast<int64_t>(d);
        if (static_cast<double>(i) == d) {
          // Trailing "." keeps integral doubles distinguishable from ints.
          return out << i << ".";
        }
      }
      // Use enough digits to print the double exactly, then put the caller's
      // precision back so the stream state is unchanged.
      auto orig_prec = out.precision();
      return out
        << std::setprecision(std::numeric_limits<double>::max_digits10)
        << d
        << std::setprecision(orig_prec);
    }
    case IValue::Tag::Int:
      return out << v.toInt();
    case IValue::Tag::Bool:
      return out << (v.toBool() ? "True" : "False");
    case IValue::Tag::Tuple:
      return printList(out, v.toTuple(), "(", ")");
    case IValue::Tag::IntList:
      return printList(out, v.toIntList(), "[", "]");
    case IValue::Tag::DoubleList:
      return printList(out, v.toDoubleList(), "[", "]");
    case IValue::Tag::BoolList:
      return printList(out, v.toBoolList(), "[", "]");
    case IValue::Tag::String:
      return out << v.toStringRef();
    case IValue::Tag::TensorList:
      return printList(out, v.toTensorList(), "[", "]");
    case IValue::Tag::Blob:
      return out << *v.toBlob();
    case IValue::Tag::GenericList:
      return printList(out, v.toGenericList(), "[", "]");
    case IValue::Tag::Future:
      return out << "Future";
    case IValue::Tag::Device:
      return out << v.toDevice();
    case IValue::Tag::GenericDict:
      return printDict(out, v.toGenericDict());
    case IValue::Tag::Object:
      // TODO we should print the object contents
      return out << "Object<" << v.toObject()->name()
                 << ">";
  }
  AT_ERROR("Tag not found\n");
}
// NOTE(review): TORCH_FORALL_TAGS is neither defined nor used in the visible
// part of this file; presumably it comes from one of the included ATen
// headers -- confirm this #undef is still needed here.
#undef TORCH_FORALL_TAGS
// Debugging aid: prints this value to std::cout using the IValue stream
// operator, followed by a newline.
void IValue::dump() const {
  std::ostream& sink = std::cout;
  sink << *this;
  sink << "\n";
}
const std::string& ivalue::Object::name() const {
return this->type_->name();
}
} // namespace c10