Skip to content

Commit cc05a28

Browse files
authored
Added ArcTan2 operation (#1079)
* Added ArcTan2 operation * Cleanup, bug fixes from code review * Minor cleanup, fixed Linux tests
1 parent fe96cee commit cc05a28

File tree

16 files changed

+143
-1
lines changed

16 files changed

+143
-1
lines changed

docs/src/python/ops.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Operations
1919
arcsin
2020
arcsinh
2121
arctan
22+
arctan2
2223
arctanh
2324
argmax
2425
argmin

mlx/backend/accelerate/primitives.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,26 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
193193
}
194194
}
195195

196+
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
197+
assert(inputs.size() == 2);
198+
auto& a = inputs[0];
199+
auto& b = inputs[1];
200+
if (out.dtype() == float32 && a.flags().row_contiguous &&
201+
b.flags().row_contiguous) {
202+
if (a.is_donatable()) {
203+
out.copy_shared_buffer(a);
204+
} else if (b.is_donatable()) {
205+
out.copy_shared_buffer(b);
206+
} else {
207+
out.set_data(allocator::malloc_or_wait(out.nbytes()));
208+
}
209+
int size = a.data_size();
210+
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
211+
} else {
212+
eval(inputs, out);
213+
}
214+
}
215+
196216
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
197217
assert(inputs.size() == 1);
198218
const auto& in = inputs[0];

mlx/backend/common/binary.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,4 +293,25 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
293293
}
294294
}
295295

296+
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
297+
assert(inputs.size() == 2);
298+
const auto& a = inputs[0];
299+
const auto& b = inputs[1];
300+
if (out.dtype() == float32) {
301+
binary_op<float>(a, b, out, detail::ArcTan2());
302+
} else if (out.dtype() == float16) {
303+
binary_op<float16_t>(a, b, out, detail::ArcTan2());
304+
} else if (out.dtype() == bfloat16) {
305+
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
306+
} else if (issubdtype(out.dtype(), inexact)) {
307+
std::ostringstream err;
308+
err << "[arctan2] Does not support " << out.dtype();
309+
throw std::invalid_argument(err.str());
310+
} else {
311+
throw std::invalid_argument(
312+
"[arctan2] Cannot compute inverse tangent for arrays"
313+
" with non floating point type.");
314+
}
315+
}
316+
296317
} // namespace mlx::core

mlx/backend/common/default_primitives.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ DEFAULT(ArcCosh)
3434
DEFAULT(ArcSin)
3535
DEFAULT(ArcSinh)
3636
DEFAULT(ArcTan)
37+
DEFAULT(ArcTan2)
3738
DEFAULT(ArcTanh)
3839
DEFAULT(ArgPartition)
3940
DEFAULT(ArgReduce)

mlx/backend/common/ops.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,13 @@ struct ArcTan {
161161
};
162162
};
163163

164+
struct ArcTan2 {
165+
template <typename T>
166+
T operator()(T y, T x) {
167+
return std::atan2(y, x);
168+
};
169+
};
170+
164171
struct ArcTanh {
165172
template <typename T>
166173
T operator()(T x) {

mlx/backend/metal/kernels/binary.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,10 @@ struct RightShift {
264264
return x >> y;
265265
};
266266
};
267+
268+
struct ArcTan2 {
269+
template <typename T>
270+
T operator()(T y, T x) {
271+
return metal::precise::atan2(y, x);
272+
}
273+
};

mlx/backend/metal/kernels/binary.metal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ instantiate_binary_types(mul, Multiply)
241241
instantiate_binary_types(sub, Subtract)
242242
instantiate_binary_types(pow, Power)
243243
instantiate_binary_types(rem, Remainder)
244+
instantiate_binary_float(arctan2, ArcTan2)
244245

245246
// NaNEqual only needed for floating point types with boolean output
246247
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)

mlx/backend/metal/primitives.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,10 @@ void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
451451
unary_op(inputs, out, "arctan");
452452
}
453453

454+
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
455+
binary_op(inputs, out, "arctan2");
456+
}
457+
454458
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
455459
unary_op(inputs, out, "arctanh");
456460
}

mlx/backend/no_metal/primitives.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ NO_GPU(ArcCosh)
2525
NO_GPU(ArcSin)
2626
NO_GPU(ArcSinh)
2727
NO_GPU(ArcTan)
28+
NO_GPU(ArcTan2)
2829
NO_GPU(ArcTanh)
2930
NO_GPU(ArgPartition)
3031
NO_GPU(ArgReduce)

mlx/compile.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ bool is_binary(const Primitive& p) {
4545
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
4646
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
4747
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
48-
typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary));
48+
typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary) ||
49+
typeid(p) == typeid(ArcTan2));
4950
}
5051

5152
bool is_ternary(const Primitive& p) {

0 commit comments

Comments
 (0)