Skip to content
Open
3 changes: 2 additions & 1 deletion benchmarks/src/replace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ void rc(benchmark::State& state) {
}
}

// replace() is vectorized for 4 and 8 bytes only.
BENCHMARK(r<std::uint8_t>);
BENCHMARK(r<std::uint16_t>);
BENCHMARK(r<std::uint32_t>);
BENCHMARK(r<std::uint64_t>);

Expand Down
44 changes: 35 additions & 9 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ __declspec(noalias) bool __stdcall __std_includes_less_8u(
#endif // ^^^ _VECTORIZED_INCLUDES ^^^

#if _VECTORIZED_REPLACE
#if _VECTORIZED_REPLACE_1_2
__declspec(noalias) void __stdcall __std_replace_1(
void* _First, void* _Last, uint8_t _Old_val, uint8_t _New_val) noexcept;
__declspec(noalias) void __stdcall __std_replace_2(
void* _First, void* _Last, uint16_t _Old_val, uint16_t _New_val) noexcept;
#endif // ^^^ _VECTORIZED_REPLACE_1_2 ^^^

// TRANSITION, DevCom-10610477
__declspec(noalias) void __stdcall __std_replace_4(
void* _First, void* _Last, uint32_t _Old_val, uint32_t _New_val) noexcept;
Expand Down Expand Up @@ -383,14 +390,25 @@ bool _Includes_vectorized(
template <class _Ty, class _TVal1, class _TVal2>
__declspec(noalias) void _Replace_vectorized(
_Ty* const _First, _Ty* const _Last, const _TVal1 _Old_val, const _TVal2 _New_val) noexcept {
if constexpr (sizeof(_Ty) == 4) {
::__std_replace_4(
_First, _Last, _STD _Find_arg_cast<uint32_t>(_Old_val), _STD _Find_arg_cast<uint32_t>(_New_val));
} else if constexpr (sizeof(_Ty) == 8) {
::__std_replace_8(
_First, _Last, _STD _Find_arg_cast<uint64_t>(_Old_val), _STD _Find_arg_cast<uint64_t>(_New_val));
} else {
static_assert(false, "unexpected size");
#if _VECTORIZED_REPLACE_1_2
if constexpr (sizeof(_Ty) == 1) {
::__std_replace_1(
_First, _Last, _STD _Find_arg_cast<uint8_t>(_Old_val), _STD _Find_arg_cast<uint8_t>(_New_val));
} else if constexpr (sizeof(_Ty) == 2) {
::__std_replace_2(
_First, _Last, _STD _Find_arg_cast<uint16_t>(_Old_val), _STD _Find_arg_cast<uint16_t>(_New_val));
} else
#endif // ^^^ _VECTORIZED_REPLACE_1_2 ^^^
{
if constexpr (sizeof(_Ty) == 4) {
::__std_replace_4(
_First, _Last, _STD _Find_arg_cast<uint32_t>(_Old_val), _STD _Find_arg_cast<uint32_t>(_New_val));
} else if constexpr (sizeof(_Ty) == 8) {
::__std_replace_8(
_First, _Last, _STD _Find_arg_cast<uint64_t>(_Old_val), _STD _Find_arg_cast<uint64_t>(_New_val));
} else {
static_assert(false, "unexpected size");
}
}
}
#endif // ^^^ _VECTORIZED_REPLACE ^^^
Expand Down Expand Up @@ -491,10 +509,18 @@ _Ty* _Unique_copy_vectorized(const _Ty* const _First, const _Ty* const _Last, _T
#endif // ^^^ _VECTORIZED_UNIQUE_COPY ^^^

#if _VECTORIZED_REPLACE
#if _VECTORIZED_REPLACE_1_2
template <class _Iter>
constexpr bool _Have_masked_op_for_iter = true;
#else // ^^^ _VECTORIZED_REPLACE_1_2 / !_VECTORIZED_REPLACE_1_2 vvv
template <class _Iter>
constexpr bool _Have_masked_op_for_iter = sizeof(_Iter_value_t<_Iter>) >= 4; // avx masked op compatible size
#endif // ^^^ !_VECTORIZED_REPLACE_1_2 ^^^
Comment thread
StephanTLavavej marked this conversation as resolved.

// Can we activate the vector algorithms for replace?
template <class _Iter, class _Ty1>
constexpr bool _Vector_alg_in_replace_is_safe = _Vector_alg_in_find_is_safe<_Iter, _Ty1> // can search for the value
&& sizeof(_Iter_value_t<_Iter>) >= 4; // avx masked op compatible size
&& _Have_masked_op_for_iter<_Iter>;

// Can we activate the vector algorithms for ranges::replace?
template <class _Iter, class _Ty1, class _Ty2>
Expand Down
8 changes: 7 additions & 1 deletion stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ _STL_DISABLE_CLANG_WARNINGS
#define _VECTORIZED_MISMATCH _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC
#define _VECTORIZED_REMOVE _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC
#define _VECTORIZED_REMOVE_COPY _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC
#define _VECTORIZED_REPLACE _VECTORIZED_FOR_X64_X86
#define _VECTORIZED_REPLACE _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC
#define _VECTORIZED_REPLACE_COPY _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC
#define _VECTORIZED_REVERSE _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC
#define _VECTORIZED_REVERSE_COPY _VECTORIZED_FOR_X64_X86_ARM64_ARM64EC
Expand All @@ -104,6 +104,12 @@ _STL_DISABLE_CLANG_WARNINGS
// as this does not improve performance over the scalar code.
#define _VECTORIZED_MINMAX_ELEMENT_64BIT_INT _VECTORIZED_FOR_X64_X86

#if defined(_M_ARM64) || defined(_M_ARM64EC)
#define _VECTORIZED_REPLACE_1_2 1
#else
#define _VECTORIZED_REPLACE_1_2 0
#endif

#ifndef _USE_STD_VECTOR_FLOATING_ALGORITHMS
#if _USE_STD_VECTOR_ALGORITHMS && !defined(_M_FP_EXCEPT)
#define _USE_STD_VECTOR_FLOATING_ALGORITHMS 1
Expand Down
145 changes: 135 additions & 10 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#if defined(_M_ARM64) || defined(_M_ARM64EC)
#include <arm64_neon.h>
#include <arm_sve.h>

#include <Windows.h>
#else // ^^^ defined(_M_ARM64) || defined(_M_ARM64EC) / !defined(_M_ARM64) && !defined(_M_ARM64EC) vvv
Expand Down Expand Up @@ -9607,6 +9608,116 @@ __declspec(noalias) size_t __stdcall __std_mismatch_8(
namespace {
namespace _Replacing {
#if defined(_M_ARM64) || defined(_M_ARM64EC)
struct _Traits_1_sve {
static svuint8_t _Load(const svbool_t _Pred, const void* const _Ptr) noexcept {
return svld1(_Pred, static_cast<const uint8_t*>(_Ptr));
}

static svuint8_t _Set(const uint8_t _Val) noexcept {
return svdup_n_u8(_Val);
}

static svbool_t _Cmp(const svbool_t _Pred, const svuint8_t _Lhs, const svuint8_t _Rhs) noexcept {
return svcmpeq(_Pred, _Lhs, _Rhs);
}

static void _Store(const svbool_t _Pred, void* const _Ptr, const svuint8_t _Val) noexcept {
svst1(_Pred, static_cast<uint8_t*>(_Ptr), _Val);
}
};

struct _Traits_2_sve {
static svuint16_t _Load(const svbool_t _Pred, const void* const _Ptr) noexcept {
return svld1(_Pred, static_cast<const uint16_t*>(_Ptr));
}

static svuint16_t _Set(const uint16_t _Val) noexcept {
return svdup_n_u16(_Val);
}

static svbool_t _Cmp(const svbool_t _Pred, const svuint16_t _Lhs, const svuint16_t _Rhs) noexcept {
return svcmpeq(_Pred, _Lhs, _Rhs);
}

static void _Store(const svbool_t _Pred, void* const _Ptr, const svuint16_t _Val) noexcept {
svst1(_Pred, static_cast<uint16_t*>(_Ptr), _Val);
}
};

struct _Traits_4_sve {
static svuint32_t _Load(const svbool_t _Pred, const void* const _Ptr) noexcept {
return svld1(_Pred, static_cast<const uint32_t*>(_Ptr));
}

static svuint32_t _Set(const uint32_t _Val) noexcept {
return svdup_n_u32(_Val);
}

static svbool_t _Cmp(const svbool_t _Pred, const svuint32_t _Lhs, const svuint32_t _Rhs) noexcept {
return svcmpeq(_Pred, _Lhs, _Rhs);
}

static void _Store(const svbool_t _Pred, void* const _Ptr, const svuint32_t _Val) noexcept {
svst1(_Pred, static_cast<uint32_t*>(_Ptr), _Val);
}
};

struct _Traits_8_sve {
static svuint64_t _Load(const svbool_t _Pred, const void* const _Ptr) noexcept {
return svld1(_Pred, static_cast<const uint64_t*>(_Ptr));
}

static svuint64_t _Set(const uint64_t _Val) noexcept {
return svdup_n_u64(_Val);
}

static svbool_t _Cmp(const svbool_t _Pred, const svuint64_t _Lhs, const svuint64_t _Rhs) noexcept {
return svcmpeq(_Pred, _Lhs, _Rhs);
}

static void _Store(const svbool_t _Pred, void* const _Ptr, const svuint64_t _Val) noexcept {
svst1(_Pred, static_cast<uint64_t*>(_Ptr), _Val);
}
};

template <class _Traits, class _Ty>
__declspec(noalias) void __stdcall _Replace_impl(
void* _First, void* const _Last, const _Ty _Old_val, const _Ty _New_val) noexcept {

if (_Use_FEAT_SVE()) {
const size_t _Sve_vl = svcntb();
const size_t _Size_bytes = _Byte_length(_First, _Last);
const size_t _Full_vl_bytes = _Size_bytes & ~size_t{_Sve_vl - 1};

Comment thread
StephanTLavavej marked this conversation as resolved.
const void* _Stop_at = _First;
_Advance_bytes(_Stop_at, _Full_vl_bytes);

const auto _Comparand = _Traits::_Set(_Old_val);
const auto _Replacement = _Traits::_Set(_New_val);

const auto _True = svptrue_b8();
while (_First != _Stop_at) {
const auto _Data = _Traits::_Load(_True, _First);
const auto _Mask = _Traits::_Cmp(_True, _Data, _Comparand);
_Traits::_Store(_Mask, _First, _Replacement);
_Advance_bytes(_First, _Sve_vl);
}

if (const size_t _Tail_length = _Size_bytes & size_t{_Sve_vl - 1}; _Tail_length != 0) {
const auto _Tail_mask = svwhilelt_b8(size_t{0}, _Tail_length);
const auto _Data = _Traits::_Load(_Tail_mask, _First);
const auto _Mask = _Traits::_Cmp(_Tail_mask, _Data, _Comparand);
_Traits::_Store(_Mask, _First, _Replacement);
}
} else {
for (auto _Cur = static_cast<_Ty*>(_First); _Cur != _Last; ++_Cur) {
if (*_Cur == _Old_val) {
*_Cur = _New_val;
}
}
}
}

template <class _Traits, class _Ty>
__declspec(noalias) void __stdcall _Replace_copy_impl(
const void* _First, const void* const _Last, void* _Dest, const _Ty _Old_val, const _Ty _New_val) noexcept {
Expand Down Expand Up @@ -9745,10 +9856,29 @@ namespace {

extern "C" {

#ifndef _M_ARM64
#if defined(_M_ARM64) || defined(_M_ARM64EC)
__declspec(noalias) void __stdcall __std_replace_1(
void* const _First, void* const _Last, const uint8_t _Old_val, const uint8_t _New_val) noexcept {
_Replacing::_Replace_impl<_Replacing::_Traits_1_sve>(_First, _Last, _Old_val, _New_val);
}

__declspec(noalias) void __stdcall __std_replace_2(
void* const _First, void* const _Last, const uint16_t _Old_val, const uint16_t _New_val) noexcept {
_Replacing::_Replace_impl<_Replacing::_Traits_2_sve>(_First, _Last, _Old_val, _New_val);
}

__declspec(noalias) void __stdcall __std_replace_4(
void* const _First, void* const _Last, const uint32_t _Old_val, const uint32_t _New_val) noexcept {
_Replacing::_Replace_impl<_Replacing::_Traits_4_sve>(_First, _Last, _Old_val, _New_val);
}

__declspec(noalias) void __stdcall __std_replace_8(
void* const _First, void* const _Last, const uint64_t _Old_val, const uint64_t _New_val) noexcept {
_Replacing::_Replace_impl<_Replacing::_Traits_8_sve>(_First, _Last, _Old_val, _New_val);
}
#else // ^^^ defined(_M_ARM64) || defined(_M_ARM64EC) / !defined(_M_ARM64) && !defined(_M_ARM64EC) vvv
__declspec(noalias) void __stdcall __std_replace_4(
void* _First, void* const _Last, const uint32_t _Old_val, const uint32_t _New_val) noexcept {
#ifndef _M_ARM64EC
if (_Use_avx2()) {
const __m256i _Comparand = _mm256_broadcastd_epi32(_mm_cvtsi32_si128(_Old_val));
const __m256i _Replacement = _mm256_broadcastd_epi32(_mm_cvtsi32_si128(_New_val));
Expand All @@ -9773,9 +9903,7 @@ __declspec(noalias) void __stdcall __std_replace_4(
}

_mm256_zeroupper(); // TRANSITION, DevCom-10331414
} else
#endif // ^^^ !defined(_M_ARM64EC) ^^^
{
} else {
for (auto _Cur = reinterpret_cast<uint32_t*>(_First); _Cur != _Last; ++_Cur) {
if (*_Cur == _Old_val) {
*_Cur = _New_val;
Expand All @@ -9786,7 +9914,6 @@ __declspec(noalias) void __stdcall __std_replace_4(

__declspec(noalias) void __stdcall __std_replace_8(
void* _First, void* const _Last, const uint64_t _Old_val, const uint64_t _New_val) noexcept {
#ifndef _M_ARM64EC
if (_Use_avx2()) {
#ifdef _WIN64
const __m256i _Comparand = _mm256_broadcastq_epi64(_mm_cvtsi64_si128(_Old_val));
Expand Down Expand Up @@ -9816,17 +9943,15 @@ __declspec(noalias) void __stdcall __std_replace_8(
}

_mm256_zeroupper(); // TRANSITION, DevCom-10331414
} else
#endif // ^^^ !defined(_M_ARM64EC) ^^^
{
} else {
for (auto _Cur = reinterpret_cast<uint64_t*>(_First); _Cur != _Last; ++_Cur) {
if (*_Cur == _Old_val) {
*_Cur = _New_val;
}
}
}
}
#endif // ^^^ !defined(_M_ARM64) ^^^
#endif // ^^^ !defined(_M_ARM64) && !defined(_M_ARM64EC) ^^^

__declspec(noalias) void __stdcall __std_replace_copy_1(const void* const _First, const void* const _Last,
void* const _Dest, const uint8_t _Old_val, const uint8_t _New_val) noexcept {
Expand Down
5 changes: 0 additions & 5 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1961,11 +1961,6 @@ int main() {
test_min_max_element<unsigned long long>(gen);

test_min_max_element_pointers(gen);

test_replace<int>(gen);
test_replace<unsigned int>(gen);
test_replace<long long>(gen);
test_replace<unsigned long long>(gen);
#else // ^^^ defined(_CALL_ALL_X64_VECTOR_ALGORITHMS_ON_ARM64EC) / normal test coverage vvv
test_vector_algorithms(gen);
test_various_containers();
Expand Down