Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/layer/arm/binaryop_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, div_ps(y, x))
MAKE_FUNCTION(binary_op_rpow, (float)powf(y, x), pow_ps(y, x))
MAKE_FUNCTION(binary_op_atan2, (float)atan2f(x, y), atan2_ps(x, y))
MAKE_FUNCTION(binary_op_ratan2, (float)atan2f(y, x), atan2_ps(y, x))
// remainder: the scalar expression is the floored modulo x - floor(x / y) * y,
// the same formula the generic BinaryOp layer evaluates. remainderf() is not
// equivalent: it rounds the quotient to nearest, so e.g. (5, 3) would yield -1
// instead of 2.
MAKE_FUNCTION(binary_op_remainder, (x - floorf(x / y) * y), remainder_ps(x, y))
// *INDENT-ON*
// clang-format on

Expand All @@ -308,6 +309,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr,
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector<binary_op_rpow>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector<binary_op_atan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector<binary_op_ratan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector<binary_op_remainder>(ptr, ptr1, outptr, aw, bw, ap, bp);

// should never reach here
}
Expand Down
2 changes: 2 additions & 0 deletions src/layer/arm/binaryop_arm_asimdhp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ MAKE_FUNCTION(binary_op_rdiv_fp16s, y / x, vdiv_f16(y, x), vdivq_f16(y, x))
MAKE_FUNCTION(binary_op_rpow_fp16s, (__fp16)powf(y, x), vcvt_f16_f32(pow_ps(vcvt_f32_f16(y), vcvt_f32_f16(x))), vcombine_f16(vcvt_f16_f32(pow_ps(vcvt_f32_f16(vget_low_f16(y)), vcvt_f32_f16(vget_low_f16(x)))), vcvt_f16_f32(pow_ps(vcvt_f32_f16(vget_high_f16(y)), vcvt_f32_f16(vget_high_f16(x))))))
MAKE_FUNCTION(binary_op_atan2_fp16s, (__fp16)atan2f(x, y), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(x), vcvt_f32_f16(y))), vcombine_f16(vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_low_f16(y)))), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_high_f16(x)), vcvt_f32_f16(vget_high_f16(y))))))
MAKE_FUNCTION(binary_op_ratan2_fp16s, (__fp16)atan2f(y, x), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(y), vcvt_f32_f16(x))), vcombine_f16(vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_low_f16(y)), vcvt_f32_f16(vget_low_f16(x)))), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_high_f16(y)), vcvt_f32_f16(vget_high_f16(x))))))
// remainder (fp16 storage): the scalar expression widens to float and evaluates
// the floored modulo x - floor(x / y) * y, the same formula the generic BinaryOp
// layer uses. remainderf() rounds the quotient to nearest and therefore gives a
// different value (and possibly sign) for the same inputs.
// The 4-lane and 8-lane forms widen to fp32, call remainder_ps, and narrow back.
MAKE_FUNCTION(binary_op_remainder_fp16s, (__fp16)((float)x - floorf((float)x / (float)y) * (float)y), vcvt_f16_f32(remainder_ps(vcvt_f32_f16(x), vcvt_f32_f16(y))), vcombine_f16(vcvt_f16_f32(remainder_ps(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_low_f16(y)))), vcvt_f16_f32(remainder_ps(vcvt_f32_f16(vget_high_f16(x)), vcvt_f32_f16(vget_high_f16(y))))))
// *INDENT-ON*
// clang-format on

Expand All @@ -352,6 +353,7 @@ static void binary_op_vector_fp16s(const __fp16* ptr, const __fp16* ptr1, __fp16
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector_fp16s<binary_op_rpow_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector_fp16s<binary_op_atan2_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector_fp16s<binary_op_ratan2_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector_fp16s<binary_op_remainder_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);

// should never reach here
}
Expand Down
11 changes: 11 additions & 0 deletions src/layer/arm/neon_mathfun.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,5 +395,16 @@ static inline float32x4_t atan2_ps(float32x4_t a, float32x4_t b)
return vld1q_f32(tmpx);
}

// Element-wise floored remainder of two 4-lane float vectors.
//
// Each lane computes x - floor(x / y) * y, which is the formula used by the
// generic BinaryOp layer's binary_op_remainder. remainderf() is deliberately
// not used: it rounds the quotient to the nearest integer, so it disagrees
// with the reference result (e.g. 5 rem 3 is 2 here, but remainderf gives -1).
//
// x: dividends, y: divisors. Returns the per-lane remainder.
// A zero divisor is not special-cased; the lane simply follows IEEE arithmetic.
static inline float32x4_t remainder_ps(float32x4_t x, float32x4_t y)
{
    // Spill both vectors to scalar buffers, process lane by lane, then reload.
    float tmpx[4];
    float tmpy[4];
    vst1q_f32(tmpx, x);
    vst1q_f32(tmpy, y);
    for (int i = 0; i < 4; i++)
    {
        const float quotient = floorf(tmpx[i] / tmpy[i]);
        tmpx[i] = tmpx[i] - quotient * tmpy[i];
    }
    return vld1q_f32(tmpx);
}

#include "neon_mathfun_tanh.h"
#endif // NEON_MATHFUN_H
13 changes: 13 additions & 0 deletions src/layer/binaryop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,17 @@ struct binary_op_ratan2
}
};

// Functor computing the floored remainder of x by y: x - floor(x / y) * y.
// A zero divisor is not special-cased.
struct binary_op_remainder
{
    float operator()(const float& x, const float& y) const
    {
        // Largest integer not greater than the quotient.
        const float whole_quotient = floorf(x / y);
        // Portion of x covered by whole multiples of y.
        const float covered = whole_quotient * y;
        return x - covered;
    }
};

static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt)
{
if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast<binary_op_add>(a, b, c, opt);
Expand All @@ -251,6 +262,7 @@ static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type,
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast<binary_op_pow>(b, a, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast<binary_op_atan2>(b, a, c, opt);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_broadcast<binary_op_remainder>(a, b, c, opt);

// should never reach here
}
Expand All @@ -269,6 +281,7 @@ static void binary_op_scalar_inplace(Mat& bottom_top_blob, float b, int op_type,
if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar_inplace<binary_op_rpow>(bottom_top_blob, b, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_scalar_inplace<binary_op_atan2>(bottom_top_blob, b, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_scalar_inplace<binary_op_ratan2>(bottom_top_blob, b, opt);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_scalar_inplace<binary_op_remainder>(bottom_top_blob, b, opt);

// should never reach here
}
Expand Down
3 changes: 2 additions & 1 deletion src/layer/binaryop.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class BinaryOp : public Layer
Operation_RDIV = 8,
Operation_RPOW = 9,
Operation_ATAN2 = 10,
Operation_RATAN2 = 11
Operation_RATAN2 = 11,
Operation_REMAINDER = 12
};

public:
Expand Down
2 changes: 2 additions & 0 deletions src/layer/loongarch/binaryop_loongarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, __lsx_vfdiv_s(y, x))
MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x))
MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y))
MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x))
// remainder: the scalar expression is the floored modulo x - floor(x / y) * y,
// the same formula the generic BinaryOp layer evaluates. remainderf() is not
// equivalent: it rounds the quotient to nearest, so e.g. (5, 3) would yield -1
// instead of 2.
MAKE_FUNCTION(binary_op_remainder, (x - floorf(x / y) * y), remainder_ps(x, y))
// *INDENT-ON*
// clang-format on

Expand All @@ -335,6 +336,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr,
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector<binary_op_rpow>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector<binary_op_atan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector<binary_op_ratan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector<binary_op_remainder>(ptr, ptr1, outptr, aw, bw, ap, bp);

// should never reach here
}
Expand Down
13 changes: 13 additions & 0 deletions src/layer/loongarch/lsx_mathfun.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,4 +269,17 @@ static inline __m128 atan2_ps(__m128 a, __m128 b)
return (__m128)__lsx_vld(tmpx, 0);
}

// Element-wise floored remainder of two 4-lane float vectors.
//
// Each lane computes x - floor(x / y) * y, which is the formula used by the
// generic BinaryOp layer's binary_op_remainder. remainderf() is deliberately
// not used: it rounds the quotient to the nearest integer, so it disagrees
// with the reference result (e.g. 5 rem 3 is 2 here, but remainderf gives -1).
//
// x: dividends, y: divisors. Returns the per-lane remainder.
// A zero divisor is not special-cased; the lane simply follows IEEE arithmetic.
static inline __m128 remainder_ps(__m128 x, __m128 y)
{
    // Spill both vectors to scalar buffers, process lane by lane, then reload.
    float tmpx[4];
    float tmpy[4];
    __lsx_vst(x, tmpx, 0);
    __lsx_vst(y, tmpy, 0);
    for (int i = 0; i < 4; i++)
    {
        const float quotient = floorf(tmpx[i] / tmpy[i]);
        tmpx[i] = tmpx[i] - quotient * tmpy[i];
    }
    return (__m128)__lsx_vld(tmpx, 0);
}

#endif // LSX_MATHFUN_H
2 changes: 2 additions & 0 deletions src/layer/mips/binaryop_mips.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, __msa_fdiv_w(y, x))
MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x))
MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y))
MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x))
// remainder: the scalar expression is the floored modulo x - floor(x / y) * y,
// the same formula the generic BinaryOp layer evaluates. remainderf() is not
// equivalent: it rounds the quotient to nearest, so e.g. (5, 3) would yield -1
// instead of 2.
MAKE_FUNCTION(binary_op_remainder, (x - floorf(x / y) * y), remainder_ps(x, y))
// *INDENT-ON*
// clang-format on

Expand All @@ -335,6 +336,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr,
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector<binary_op_rpow>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector<binary_op_atan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector<binary_op_ratan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector<binary_op_remainder>(ptr, ptr1, outptr, aw, bw, ap, bp);

// should never reach here
}
Expand Down
13 changes: 13 additions & 0 deletions src/layer/mips/msa_mathfun.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,4 +267,17 @@ static inline v4f32 atan2_ps(v4f32 a, v4f32 b)
return (v4f32)__msa_ld_w(tmpx, 0);
}

// Element-wise floored remainder of two 4-lane float vectors.
//
// Each lane computes x - floor(x / y) * y, which is the formula used by the
// generic BinaryOp layer's binary_op_remainder. remainderf() is deliberately
// not used: it rounds the quotient to the nearest integer, so it disagrees
// with the reference result (e.g. 5 rem 3 is 2 here, but remainderf gives -1).
//
// x: dividends, y: divisors. Returns the per-lane remainder.
// A zero divisor is not special-cased; the lane simply follows IEEE arithmetic.
static inline v4f32 remainder_ps(v4f32 x, v4f32 y)
{
    // Spill both vectors to scalar buffers, process lane by lane, then reload.
    float tmpx[4];
    float tmpy[4];
    __msa_st_w((v4i32)x, tmpx, 0);
    __msa_st_w((v4i32)y, tmpy, 0);
    for (int i = 0; i < 4; i++)
    {
        const float quotient = floorf(tmpx[i] / tmpy[i]);
        tmpx[i] = tmpx[i] - quotient * tmpy[i];
    }
    return (v4f32)__msa_ld_w(tmpx, 0);
}

#endif // MSA_MATHFUN_H
4 changes: 4 additions & 0 deletions src/layer/riscv/binaryop_riscv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, vfdiv_vv_f32m8(y, x, vl), vfrdiv_vf_f32m8(x
MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x, vl), pow_ps(vfmv_v_f_f32m8(y, vl), x, vl), pow_ps(y, vfmv_v_f_f32m8(x, vl), vl))
MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y, vl), atan2_ps(x, vfmv_v_f_f32m8(y, vl), vl), atan2_ps(vfmv_v_f_f32m8(x, vl), y, vl))
MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x, vl), atan2_ps(vfmv_v_f_f32m8(y, vl), x, vl), atan2_ps(y, vfmv_v_f_f32m8(x, vl), vl))
// remainder: the scalar expression is the floored modulo x - floor(x / y) * y,
// the same formula the generic BinaryOp layer evaluates. remainderf() is not
// equivalent: it rounds the quotient to nearest, so e.g. (5, 3) would yield -1
// instead of 2. The three vector forms cover vector/vector, vector/scalar and
// scalar/vector operands by splatting the scalar with vfmv_v_f_f32m8.
MAKE_FUNCTION(binary_op_remainder, (x - floorf(x / y) * y), remainder_ps(x, y, vl), remainder_ps(x, vfmv_v_f_f32m8(y, vl), vl), remainder_ps(vfmv_v_f_f32m8(x, vl), y, vl))
// *INDENT-ON*
// clang-format on

Expand All @@ -316,6 +317,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr,
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector<binary_op_rpow>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector<binary_op_atan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector<binary_op_ratan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector<binary_op_remainder>(ptr, ptr1, outptr, aw, bw, ap, bp);

// should never reach here
}
Expand Down Expand Up @@ -887,6 +889,7 @@ MAKE_FUNCTION(binary_op_rdiv_fp16s, y / x, vfdiv_vv_f16m8(y, x, vl), vfrdiv_vf_f
MAKE_FUNCTION(binary_op_rpow_fp16s, (__fp16)pow((float)y, (float)x), pow_ps(y, x, vl), pow_ps(vfmv_v_f_f16m8(y, vl), x, vl), pow_ps(y, vfmv_v_f_f16m8(x, vl), vl))
MAKE_FUNCTION(binary_op_atan2_fp16s, (__fp16)atan2((float)x, (float)y), atan2_ps(x, y, vl), atan2_ps(x, vfmv_v_f_f16m8(y, vl), vl), atan2_ps(vfmv_v_f_f16m8(x, vl), y, vl))
MAKE_FUNCTION(binary_op_ratan2_fp16s, (__fp16)atan2((float)y, (float)x), atan2_ps(y, x, vl), atan2_ps(vfmv_v_f_f16m8(y, vl), x, vl), atan2_ps(y, vfmv_v_f_f16m8(x, vl), vl))
// remainder (fp16): the scalar expression widens to float and evaluates the
// floored modulo x - floor(x / y) * y, the same formula the generic BinaryOp
// layer uses. remainderf() rounds the quotient to nearest and therefore gives a
// different value (and possibly sign) for the same inputs. The three vector
// forms cover vector/vector, vector/scalar and scalar/vector operands by
// splatting the scalar with vfmv_v_f_f16m8.
MAKE_FUNCTION(binary_op_remainder_fp16s, (__fp16)((float)x - floorf((float)x / (float)y) * (float)y), remainder_ps(x, y, vl), remainder_ps(x, vfmv_v_f_f16m8(y, vl), vl), remainder_ps(vfmv_v_f_f16m8(x, vl), y, vl))
// *INDENT-ON*
// clang-format on

Expand All @@ -910,6 +913,7 @@ static void binary_op_vector_fp16s(const __fp16* ptr, const __fp16* ptr1, __fp16
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector_fp16s<binary_op_rpow_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector_fp16s<binary_op_atan2_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector_fp16s<binary_op_ratan2_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector_fp16s<binary_op_remainder_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);

// should never reach here
}
Expand Down
19 changes: 19 additions & 0 deletions src/layer/riscv/rvv_mathfun.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,4 +580,23 @@ _RVV_FLOAT32_ATAN2_OP(2, 16)
_RVV_FLOAT32_ATAN2_OP(4, 8)
_RVV_FLOAT32_ATAN2_OP(8, 4)

// Generates remainder_ps(x, y, vl) for one fp32 register grouping (LMUL).
//
// The first vl lanes of x and y are stored to temporary buffers, each lane is
// replaced by the floored remainder x - floor(x / y) * y, and the result is
// loaded back into a vector. The floored form is the formula used by the
// generic BinaryOp layer's binary_op_remainder; remainderf() rounds the
// quotient to nearest and would return different values (5 rem 3 -> -1, not 2).
//
// MLEN is accepted for symmetry with the macro's call sites but is not
// referenced in the body. A zero divisor is not special-cased.
#define _RVV_FLOAT32_REMAINDER_OP(LMUL, MLEN) \
    static inline vfloat32m##LMUL##_t remainder_ps(vfloat32m##LMUL##_t x, vfloat32m##LMUL##_t y, size_t vl) \
    { \
        std::vector<float> tmpx(vl); \
        std::vector<float> tmpy(vl); \
        vse32_v_f32m##LMUL(tmpx.data(), x, vl); \
        vse32_v_f32m##LMUL(tmpy.data(), y, vl); \
        for (size_t i = 0; i < vl; i++) \
        { \
            const float quotient = floorf(tmpx[i] / tmpy[i]); \
            tmpx[i] = tmpx[i] - quotient * tmpy[i]; \
        } \
        return vle32_v_f32m##LMUL(tmpx.data(), vl); \
    }

_RVV_FLOAT32_REMAINDER_OP(1, 32)
_RVV_FLOAT32_REMAINDER_OP(2, 16)
_RVV_FLOAT32_REMAINDER_OP(4, 8)
_RVV_FLOAT32_REMAINDER_OP(8, 4)

#endif // RVV_MATHFUN_H
19 changes: 19 additions & 0 deletions src/layer/riscv/rvv_mathfun_fp16s.h
Original file line number Diff line number Diff line change
Expand Up @@ -416,4 +416,23 @@ _RVV_FLOAT16_ATAN2_OP(2, 16)
_RVV_FLOAT16_ATAN2_OP(4, 8)
_RVV_FLOAT16_ATAN2_OP(8, 4)

// Generates remainder_ps(x, y, vl) for one fp16 register grouping (LMUL).
//
// The first vl lanes of x and y are stored to temporary buffers; each lane is
// widened to float, replaced by the floored remainder x - floor(x / y) * y, and
// narrowed back to __fp16 before the result is reloaded into a vector. The
// floored form is the formula used by the generic BinaryOp layer's
// binary_op_remainder; remainderf() rounds the quotient to nearest and would
// return different values (5 rem 3 -> -1, not 2).
//
// MLEN is accepted for symmetry with the macro's call sites but is not
// referenced in the body. A zero divisor is not special-cased.
#define _RVV_FLOAT16_REMAINDER_OP(LMUL, MLEN) \
    static inline vfloat16m##LMUL##_t remainder_ps(vfloat16m##LMUL##_t x, vfloat16m##LMUL##_t y, size_t vl) \
    { \
        std::vector<__fp16> tmpx(vl); \
        std::vector<__fp16> tmpy(vl); \
        vse16_v_f16m##LMUL(tmpx.data(), x, vl); \
        vse16_v_f16m##LMUL(tmpy.data(), y, vl); \
        for (size_t i = 0; i < vl; i++) \
        { \
            const float xf = (float)tmpx[i]; \
            const float yf = (float)tmpy[i]; \
            const float quotient = floorf(xf / yf); \
            tmpx[i] = (__fp16)(xf - quotient * yf); \
        } \
        return vle16_v_f16m##LMUL(tmpx.data(), vl); \
    }

_RVV_FLOAT16_REMAINDER_OP(1, 32)
_RVV_FLOAT16_REMAINDER_OP(2, 16)
_RVV_FLOAT16_REMAINDER_OP(4, 8)
_RVV_FLOAT16_REMAINDER_OP(8, 4)

#endif // RVV_MATHFUN_FP16S_H
1 change: 1 addition & 0 deletions src/layer/vulkan/shader/binaryop.comp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ void main()
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
if (op_type == 12) res = v1 - floorf(v1 / v2) * v2;

#if NCNN_image_shader
image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
1 change: 1 addition & 0 deletions src/layer/vulkan/shader/binaryop_broadcast.comp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ void main()
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
if (op_type == 12) res = v1 - floorf(v1 / v2) * v2;

#if NCNN_image_shader
image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
1 change: 1 addition & 0 deletions src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ void main()
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
if (op_type == 12) res = v1 - floorf(v1 / v2) * v2;

#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
5 changes: 5 additions & 0 deletions src/layer/vulkan/shader/binaryop_broadcast_pack1to8.comp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ void main()
res[1] = atan(v2[1], v1[1]);
#endif
}
if (op_type == 12)
{
res[0] = v1[0] - floorf(v1[0] / v2[0]) * v2[0];
res[1] = v1[1] - floorf(v1[1] / v2[1]) * v2[1];
}

#if NCNN_image_shader
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
1 change: 1 addition & 0 deletions src/layer/vulkan/shader/binaryop_broadcast_pack4.comp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ void main()
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
if (op_type == 12) res = v1 - floorf(v1 / v2) * v2;

#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
5 changes: 5 additions & 0 deletions src/layer/vulkan/shader/binaryop_broadcast_pack8.comp
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,11 @@ void main()
res[1] = atan(v2[1], v1[1]);
#endif
}
if (op_type == 12)
{
res[0] = v1[0] - floorf(v1[0] / v2[0]) * v2[0];
res[1] = v1[1] - floorf(v1[1] / v2[1]) * v2[1];
}

#if NCNN_image_shader
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
1 change: 1 addition & 0 deletions src/layer/vulkan/shader/binaryop_pack4.comp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ void main()
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
if (op_type == 12) res = v1 - floorf(v1 / v2) * v2;

#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
5 changes: 5 additions & 0 deletions src/layer/vulkan/shader/binaryop_pack8.comp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ void main()
res[1] = atan(v2[1], v1[1]);
#endif
}
if (op_type == 12)
{
res[0] = v1[0] - floorf(v1[0] / v2[0]) * v2[0];
res[1] = v1[1] - floorf(v1[1] / v2[1]) * v2[1];
}

#if NCNN_image_shader
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
8 changes: 8 additions & 0 deletions src/layer/x86/avx512_mathfun.h
Original file line number Diff line number Diff line change
Expand Up @@ -856,4 +856,12 @@ static NCNN_FORCEINLINE __m512 abs512_ps(__m512 x)
return _mm512_andnot_ps(magic_negative_zero, x);
}

// Floored remainder of 16 packed floats: x - floor(x / y) * y, lane by lane.
// x holds the dividends and y the divisors; a zero divisor is not special-cased.
static NCNN_FORCEINLINE __m512 remainder512_ps(__m512 x, __m512 y)
{
    // Quotient rounded toward negative infinity.
    const __m512 whole_quotient = _mm512_floor_ps(_mm512_div_ps(x, y));
    // Part of x made up of whole multiples of y.
    const __m512 covered = _mm512_mul_ps(y, whole_quotient);
    return _mm512_sub_ps(x, covered);
}

#endif // AVX512_MATHFUN_H
8 changes: 8 additions & 0 deletions src/layer/x86/avx_mathfun.h
Original file line number Diff line number Diff line change
Expand Up @@ -1087,4 +1087,12 @@ static NCNN_FORCEINLINE __m256 abs256_ps(__m256 x)
return _mm256_andnot_ps(magic_negative_zero, x);
}

// Floored remainder of 8 packed floats: x - floor(x / y) * y, lane by lane.
// x holds the dividends and y the divisors; a zero divisor is not special-cased.
static NCNN_FORCEINLINE __m256 remainder256_ps(__m256 x, __m256 y)
{
    // Quotient rounded toward negative infinity.
    const __m256 whole_quotient = _mm256_floor_ps(_mm256_div_ps(x, y));
    // Part of x made up of whole multiples of y.
    const __m256 covered = _mm256_mul_ps(y, whole_quotient);
    return _mm256_sub_ps(x, covered);
}

#endif // AVX_MATHFUN_H
30 changes: 30 additions & 0 deletions src/layer/x86/binaryop_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,35 @@ struct binary_op_ratan2
#endif // __SSE2__
};

// Remainder functor for the x86 BinaryOp kernels.
// func handles one float; the func_packN members forward N packed floats to the
// matching remainder helper and are only compiled when that instruction set is
// enabled.
struct binary_op_remainder
{
    // Scalar floored remainder: x - floor(x / y) * y.
    float func(const float& x, const float& y) const
    {
        const float whole_quotient = floorf(x / y);
        const float covered = whole_quotient * y;
        return x - covered;
    }
#if __SSE2__
    // 4 lanes (SSE2).
    __m128 func_pack4(const __m128& x, const __m128& y) const
    {
        return remainder_ps(x, y);
    }
#if __AVX__
    // 8 lanes (AVX).
    __m256 func_pack8(const __m256& x, const __m256& y) const
    {
        return remainder256_ps(x, y);
    }
#if __AVX512F__
    // 16 lanes (AVX-512F).
    __m512 func_pack16(const __m512& x, const __m512& y) const
    {
        return remainder512_ps(x, y);
    }
#endif // __AVX512F__
#endif // __AVX__
#endif // __SSE2__
};

} // namespace BinaryOp_x86_functor

static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, int aw, int bw, int ap, int bp, int op_type)
Expand All @@ -807,6 +836,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr,
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector<binary_op_rpow>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector<binary_op_atan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector<binary_op_ratan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector<binary_op_remainder>(ptr, ptr1, outptr, aw, bw, ap, bp);

// should never reach here
}
Expand Down
Loading