Skip to content

Commit fe6adf8

Browse files
committed
Playing with Highway Gw
1 parent c320853 commit fe6adf8

File tree

3 files changed

+53
-15
lines changed

3 files changed

+53
-15
lines changed

common/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ drake_cc_library(
406406
# x86: SIMD with the improved 512VL encodings.
407407
# TODO(jwnimmer-tri) Enable this once we support opt-out of CPU
408408
# targets via NPY_DISABLE_CPU_FEATURES environment variable.
409-
# "HWY_AVX3",
409+
"HWY_AVX3",
410410
# arm: SIMD that has fixed 256-bit lanes.
411411
# TODO(jwnimmer-tri) Enable this once we can test it.
412412
# "HWY_SVE_256",

math/fast_pose_composition_functions.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ they partially overlap or overlap with R_AB.
7777
This requires 30 floating point operations but can be done very efficiently
7878
exploiting SIMD instructions when available. */
7979
void ReexpressSpatialVector(const RotationMatrix<double>& R_AB,
80-
const Vector6<double>& V_B,
81-
Vector6<double>* V_A);
80+
const Vector6<double>& V_B, Vector6<double>* V_A);
8281

8382
void CrossProduct(const Vector3<double>& w, const Vector3<double>& r,
8483
Vector3<double>* wXr);

math/fast_pose_composition_functions_avx2_fma.cc

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -616,8 +616,7 @@ We want to perform two matrix-vector products:
616616
We can do this in 6 SIMD instructions. We end up doing 40 flops and throwing
617617
10 of them away.
618618
*/
619-
void ReexpressSpatialVectorImpl(const double* R_AB,
620-
const double* V_B,
619+
void ReexpressSpatialVectorImpl(const double* R_AB, const double* V_B,
621620
double* V_A) {
622621
const hn::FixedTag<double, 4> tag;
623622

@@ -647,7 +646,6 @@ void ReexpressSpatialVectorImpl(const double* R_AB,
647646
hn::StoreN(RST_, tag, V_A + 3, 3); // 3-wide write to stay in bounds
648647
}
649648

650-
651649
// w = uvw, r = xyz
652650
// w X r = vz - wy
653651
// wx - uz
@@ -663,12 +661,51 @@ void CrossProductImpl(const double* w, const double* r, double* wXr) {
663661
const auto yzx_ = hn::Per4LaneBlockShuffle<3, 0, 2, 1>(xyz_); // r120
664662
const auto zxy_ = hn::Per4LaneBlockShuffle<3, 1, 0, 2>(xyz_); // r201
665663

666-
const auto right = hn::Mul(wuv_, yzx_); // w201 * r120
664+
const auto right = hn::Mul(wuv_, yzx_); // w201 * r120
667665
const auto wXr_ = hn::MulSub(vwu_, zxy_, right); // w120*r201 - right
668666

669667
hn::StoreN(wXr_, tag, wXr, 3);
670668
}
671669

670+
671+
// Computes the double cross product w x (w x r) for 3-element vectors.
// w, r: pointers to three doubles each (only three lanes are loaded).
// wXwXr: output; exactly three doubles are written.
// With w = uvw and r = xyz, each cross product a x b is formed as
//   a120 * b201 - a201 * b120
// i.e. one multiply followed by one fused multiply-subtract.
672+
void CrossCrossProductImpl(const double* w, const double* r, double* wXwXr) {
673+
// 4-lane double vectors; only the first three lanes carry data.
const hn::FixedTag<double, 4> tag;
674+
675+
// Load w and build its two cyclic permutations.
const auto uvw_ = hn::LoadN(tag, w, 3);
676+
const auto vwu_ = hn::Per4LaneBlockShuffle<3, 0, 2, 1>(uvw_); // w120
677+
const auto wuv_ = hn::Per4LaneBlockShuffle<3, 1, 0, 2>(uvw_); // w201
678+
679+
// Load r and build its two cyclic permutations.
const auto xyz_ = hn::LoadN(tag, r, 3);
680+
const auto yzx_ = hn::Per4LaneBlockShuffle<3, 0, 2, 1>(xyz_); // r120
681+
const auto zxy_ = hn::Per4LaneBlockShuffle<3, 1, 0, 2>(xyz_); // r201
682+
683+
// Inner product: wXr = w x r.
const auto right = hn::Mul(wuv_, yzx_); // w201 * r120
684+
const auto wXr_ = hn::MulSub(vwu_, zxy_, right); // w120*r201 - right
685+
686+
// Permute the inner result the same way r was permuted above.
const auto wXr_120 = hn::Per4LaneBlockShuffle<3, 0, 2, 1>(wXr_);
687+
const auto wXr_201 = hn::Per4LaneBlockShuffle<3, 1, 0, 2>(wXr_);
688+
689+
// Outer product: w x (w x r), reusing the permutations of w.
const auto wXwXr_right = hn::Mul(wuv_, wXr_120); // w201 * wXr_120
690+
const auto wXwXr_ = hn::MulSub(vwu_, wXr_201, wXwXr_right); // w120*wXr_201 - wXwXr_right
691+
692+
// 3-wide store so that only three doubles of the output are written.
hn::StoreN(wXwXr_, tag, wXwXr, 3);
693+
}
694+
695+
// Work in progress (see the TODO at the end): judging by the names, this is
// meant to form Gw = G * w for a symmetric 3x3 matrix G, but at present it only
// stores an intermediate shuffle of G's entries into Gw; w is loaded and never
// used.
// The lane names below (abcx, cxde) suggest G is stored as 9 doubles, column by
// column, with only the lower triangle meaningful -- TODO confirm with callers.
// G is  a - -   where "-" is an upper-triangle entry implied by symmetry
696+
//       b d -
697+
//       c e f
698+
void SymTimesVectorImpl(const double* G, const double* w, double* Gw) {
699+
const hn::FixedTag<double, 4> tag;
700+
// NOTE(review): uvw_ is currently unused; the product with w is not yet formed.
const auto uvw_ = hn::LoadN(tag, w, 3);
701+
// Full 4-lane loads of G[0..3] and G[2..5]; "x" names the lane holding an
// upper-triangle entry that is not wanted.
const auto abcx = hn::LoadU(tag, G);
702+
const auto cxde = hn::LoadU(tag, G + 2);
703+
// Combine the lower half of abcx with the upper half of cxde.
const auto abde = hn::ConcatUpperLower(tag, cxde, abcx);
704+
// NOTE(review): the name implies lanes (b, d, e, 0). Confirm against the
// Highway reference that ShiftLeftLanes<1> on four doubles gives this; lane
// shifts there are described as acting within independent 128-bit blocks and
// the shift direction should be double-checked.
const auto bde0 = hn::ShiftLeftLanes<1>(tag, abde);
705+
// NOTE(review): StoreU with a 4-lane tag presumably writes four doubles. The
// sibling routines in this file use StoreN(..., 3) for 3-element outputs to
// stay in bounds -- confirm Gw has room for a fourth double or use StoreN.
hn::StoreU(bde0, tag, Gw);
706+
// TODO(sherm1) Need cef_.
707+
}
708+
672709
#else // HWY_MAX_BYTES
673710

674711
/* The portable versions are always defined. They should be written to maximize
@@ -776,19 +813,22 @@ void ComposeXinvXImpl(const double* X_BA, const double* X_BC, double* X_AC) {
776813
std::copy(X_AC_temp, X_AC_temp + 12, X_AC);
777814
}
778815

779-
void ReexpressSpatialVectorImpl(const double* R_AB,
780-
const double* V_B,
816+
void ReexpressSpatialVectorImpl(const double* R_AB, const double* V_B,
781817
double* V_A) {
782818
DRAKE_ASSERT(V_A != nullptr);
783819
double x, y, z; // Protect from overlap with V_B.
784820
x = row_x_col(&R_AB[0], &V_B[0]);
785821
y = row_x_col(&R_AB[1], &V_B[0]);
786822
z = row_x_col(&R_AB[2], &V_B[0]);
787-
V_A[0] = x; V_A[1] = y; V_A[2] = z;
823+
V_A[0] = x;
824+
V_A[1] = y;
825+
V_A[2] = z;
788826
x = row_x_col(&R_AB[0], &V_B[3]);
789827
y = row_x_col(&R_AB[1], &V_B[3]);
790828
z = row_x_col(&R_AB[2], &V_B[3]);
791-
V_A[3] = x; V_A[4] = y; V_A[5] = z;
829+
V_A[3] = x;
830+
V_A[4] = y;
831+
V_A[5] = z;
792832
}
793833

794834
// w = uvw, r = xyz
@@ -903,16 +943,15 @@ void ComposeXinvX(const RigidTransform<double>& X_BA,
903943
}
904944

905945
void ReexpressSpatialVector(const RotationMatrix<double>& R_AB,
906-
const Vector6<double>& V_B,
907-
Vector6<double>* V_A) {
946+
const Vector6<double>& V_B, Vector6<double>* V_A) {
908947
LateBoundFunction<ChooseBestReexpressSpatialVector>::Call(
909948
GetRawData(R_AB), GetRawData(V_B), GetRawData(V_A));
910949
}
911950

912951
void CrossProduct(const Vector3<double>& w, const Vector3<double>& r,
913952
Vector3<double>* wXr) {
914-
LateBoundFunction<ChooseBestShuffleVector>::Call(
915-
GetRawData(w), GetRawData(r), GetRawData(wXr));
953+
LateBoundFunction<ChooseBestShuffleVector>::Call(GetRawData(w), GetRawData(r),
954+
GetRawData(wXr));
916955
}
917956

918957
} // namespace internal

0 commit comments

Comments
 (0)