77
88namespace mlx ::core {
99
10+ void PrintFormatter::print (std::ostream& os, bool val) {
11+ if (capitalize_bool) {
12+ os << (val ? " True" : " False" );
13+ } else {
14+ os << val;
15+ }
16+ }
17+ inline void PrintFormatter::print (std::ostream& os, int16_t val) {
18+ os << val;
19+ }
20+ inline void PrintFormatter::print (std::ostream& os, uint16_t val) {
21+ os << val;
22+ }
23+ inline void PrintFormatter::print (std::ostream& os, int32_t val) {
24+ os << val;
25+ }
26+ inline void PrintFormatter::print (std::ostream& os, uint32_t val) {
27+ os << val;
28+ }
29+ inline void PrintFormatter::print (std::ostream& os, int64_t val) {
30+ os << val;
31+ }
32+ inline void PrintFormatter::print (std::ostream& os, uint64_t val) {
33+ os << val;
34+ }
35+ inline void PrintFormatter::print (std::ostream& os, float16_t val) {
36+ os << val;
37+ }
38+ inline void PrintFormatter::print (std::ostream& os, bfloat16_t val) {
39+ os << val;
40+ }
41+ inline void PrintFormatter::print (std::ostream& os, float val) {
42+ os << val;
43+ }
44+ inline void PrintFormatter::print (std::ostream& os, complex64_t val) {
45+ os << val;
46+ }
47+
48+ PrintFormatter global_formatter;
49+
1050Dtype result_type (const std::vector<array>& arrays) {
1151 std::vector<Dtype> dtypes (1 , bool_);
1252 for (auto & arr : arrays) {
@@ -136,7 +176,7 @@ void print_subarray(std::ostream& os, const array& a, size_t index, int dim) {
136176 i = n - num_print - 1 ;
137177 index += s * (n - 2 * num_print - 1 );
138178 } else if (is_last) {
139- os << a.data <T>()[index];
179+ global_formatter. print (os, a.data <T>()[index]) ;
140180 } else {
141181 print_subarray<T>(os, a, index, dim + 1 );
142182 }
@@ -153,7 +193,7 @@ void print_array(std::ostream& os, const array& a) {
153193 os << " array(" ;
154194 if (a.ndim () == 0 ) {
155195 auto data = a.data <T>();
156- os << data[0 ];
196+ global_formatter. print (os, data[0 ]) ;
157197 } else {
158198 print_subarray<T>(os, a, 0 , 0 );
159199 }
0 commit comments