Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions include/matx/operators/interp.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,17 @@ namespace matx {
template <typename ShapeType, typename Executor>
__MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, [[maybe_unused]] Executor &&ex) const {

// Forward PreRun to the operands to support generic operators/transforms
if constexpr (is_matx_op<OpX>()) {
x_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
if constexpr (is_matx_op<OpV>()) {
v_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
if constexpr (is_matx_op<OpXQ>()) {
xq_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

// Allocate temporary storage for spline coefficients
if (method_ == InterpMethod::SPLINE) {
static_assert(is_cuda_executor_v<Executor>, "cubic spline interpolation only supports the CUDA executor currently");
Expand Down Expand Up @@ -481,6 +492,16 @@ namespace matx {
template <typename ShapeType, typename Executor>
__MATX_INLINE__ void PostRun([[maybe_unused]] ShapeType &&shape,
[[maybe_unused]] Executor &&ex) const noexcept {
if constexpr (is_matx_op<OpX>()) {
x_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
if constexpr (is_matx_op<OpV>()) {
v_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
if constexpr (is_matx_op<OpXQ>()) {
xq_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if (method_ == InterpMethod::SPLINE) {
matxFree(ptr_m_);
}
Expand Down
11 changes: 10 additions & 1 deletion include/matx/operators/polyval.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ namespace matx
if constexpr (is_matx_op<Op>()) {
op_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_matx_op<Coeffs>()) {
coeffs_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}

template <typename ShapeType, typename Executor>
Expand All @@ -168,6 +172,10 @@ namespace matx
if constexpr (is_matx_op<Op>()) {
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_matx_op<Coeffs>()) {
coeffs_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}

template <OperatorCapability Cap, typename InType>
Expand All @@ -183,7 +191,8 @@ namespace matx
}
else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) {
#ifdef MATX_EN_JIT
return combine_capabilities<Cap>(true, detail::get_operator_capability<Cap>(op_, in));
return combine_capabilities<Cap>(true, detail::get_operator_capability<Cap>(op_, in),
detail::get_operator_capability<Cap>(coeffs_, in));
#else
return false;
#endif
Expand Down
11 changes: 10 additions & 1 deletion include/matx/operators/remap.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ namespace matx
if constexpr (is_matx_op<T>()) {
op_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_matx_op<IdxType>()) {
idx_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}

template <typename ShapeType, typename Executor>
Expand All @@ -239,6 +243,10 @@ namespace matx
if constexpr (is_matx_op<T>()) {
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_matx_op<IdxType>()) {
idx_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}

template <OperatorCapability Cap, typename InType>
Expand All @@ -254,7 +262,8 @@ namespace matx
}
else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) {
#ifdef MATX_EN_JIT
return combine_capabilities<Cap>(true, detail::get_operator_capability<Cap>(op_, in));
return combine_capabilities<Cap>(true, detail::get_operator_capability<Cap>(op_, in),
detail::get_operator_capability<Cap>(idx_, in));
#else
return false;
#endif
Expand Down
23 changes: 20 additions & 3 deletions include/matx/operators/sar_bp.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,11 @@ namespace detail {
template <OperatorCapability Cap>
__MATX_INLINE__ __MATX_HOST__ auto get_capability() const {
auto self_has_cap = capability_attributes<Cap>::default_value;
return combine_capabilities<Cap>(self_has_cap,
return combine_capabilities<Cap>(self_has_cap,
detail::get_operator_capability<Cap>(initial_image_),
detail::get_operator_capability<Cap>(range_profiles_),
detail::get_operator_capability<Cap>(platform_positions_),
detail::get_operator_capability<Cap>(voxel_locations_),
detail::get_operator_capability<Cap>(range_to_mcp_));
}

Expand Down Expand Up @@ -276,7 +277,15 @@ namespace detail {
if constexpr (is_matx_op<PlatPosType>()) {
platform_positions_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}

if constexpr (is_matx_op<VoxLocType>()) {
voxel_locations_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_matx_op<RangeToMcpType>()) {
range_to_mcp_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}

template <typename ShapeType, typename Executor>
__MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
Expand All @@ -303,8 +312,16 @@ namespace detail {
platform_positions_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_matx_op<VoxLocType>()) {
voxel_locations_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_matx_op<RangeToMcpType>()) {
range_to_mcp_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

matxFree(ptr);
}
}

constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
{
Expand Down
11 changes: 10 additions & 1 deletion include/matx/operators/shift.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ namespace matx
}
else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) {
#ifdef MATX_EN_JIT
return combine_capabilities<Cap>(true, detail::get_operator_capability<Cap>(op_, in));
return combine_capabilities<Cap>(true, detail::get_operator_capability<Cap>(op_, in),
detail::get_operator_capability<Cap>(shift_, in));
#else
return false;
#endif
Expand Down Expand Up @@ -276,6 +277,10 @@ namespace matx
if constexpr (is_matx_op<T1>()) {
op_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_matx_op<T2>()) {
shift_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}

template <typename ShapeType, typename Executor>
Expand All @@ -284,6 +289,10 @@ namespace matx
if constexpr (is_matx_op<T1>()) {
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_matx_op<T2>()) {
shift_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
Expand Down
53 changes: 53 additions & 0 deletions test/00_operators/interp_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "matx.h"
#include "test_types.h"
#include "utilities.h"
#include "prerun_tester.h"

using namespace matx;
using namespace matx::test;
Expand Down Expand Up @@ -256,3 +257,55 @@ TEST(InterpTests, Interp)

MATX_EXIT_HANDLER();
}

// Verify that interp1 forwards PreRun/PostRun to ALL of its operands (sample
// points, sample values, query points) by wrapping each one in a lifecycle
// probe. A single invocation covers every forwarded operand.
template <typename ExecType>
static void InterpForwardingCheck(ExecType exec)
{
using TestType = float;

auto x = make_tensor<TestType>({5});
x.SetVals({0.0, 1.0, 3.0, 3.5, 4.0});
auto v = make_tensor<TestType>({5});
v.SetVals({0.0, 2.0, 1.0, 3.0, 4.0});
auto xq = make_tensor<TestType>({6});
xq.SetVals({-1.0, 0.0, 0.25, 1.0, 1.5, 5.0});

// Reference computed from the raw operands.
auto out_ref = make_tensor<TestType>({xq.Size(0)});
(out_ref = interp1(x, v, xq, InterpMethod::LINEAR)).run(exec);

// Wrap every forwarded operand in a lifecycle probe.
PreRunLifecycle sx, sv, sxq;
auto out_test = make_tensor<TestType>({xq.Size(0)});
(out_test = interp1(make_prerun_tester(x, sx), make_prerun_tester(v, sv),
make_prerun_tester(xq, sxq), InterpMethod::LINEAR))
.run(exec);
exec.sync();

// The forwarding contract is validated by the lifecycle counters: each
// operand's PreRun/PostRun must have been forwarded exactly once.
ExpectLifecycleClean(sx, "x");
ExpectLifecycleClean(sv, "v");
ExpectLifecycleClean(sxq, "xq");

// The probe forwards to the wrapped operand transparently, so out_test always
// equals out_ref regardless of forwarding; this only confirms the probe is
// transparent. The lifecycle counters above are what test the forwarding fix.
for (index_t i = 0; i < xq.Size(0); i++) {
ASSERT_NEAR(out_test(i), out_ref(i), 1e-4) << "mismatch at index " << i;
}
}

TEST(InterpTests, InterpOperatorInput)
{
MATX_ENTER_HANDLER();
// Run on the CUDA executor only: interp1 requires CUDA for the SPLINE method,
// and the template instantiation for a host executor triggers the SPLINE
// static_assert even for LINEAR. The host executor is already exercised by the
// pre-existing InterpTests.Interp test.
InterpForwardingCheck(cudaExecutor{});
MATX_EXIT_HANDLER();
}
42 changes: 41 additions & 1 deletion test/00_operators/polyval_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "matx.h"
#include "test_types.h"
#include "utilities.h"
#include "prerun_tester.h"

using namespace matx;
using namespace matx::test;
Expand Down Expand Up @@ -31,4 +32,43 @@ TYPED_TEST(OperatorTestsFloatNonComplexNonHalfAllExecs, PolyVal)
MATX_TEST_ASSERT_COMPARE(pb, out, "out", 0.01);

MATX_EXIT_HANDLER();
}
}

// Verify polyval forwards PreRun/PostRun to both of its operands (input values
// and coefficients) by wrapping each in a lifecycle probe.
TYPED_TEST(OperatorTestsFloatNonComplexNonHalfAllExecsWithoutJIT, PolyValOperatorCoeffs)
{
MATX_ENTER_HANDLER();
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;

ExecType exec{};

constexpr int N = 5;
constexpr int NC = 4;

auto x = make_tensor<TestType>({N});
auto c = make_tensor<TestType>({NC});
x.SetVals({0.5, 1.0, 1.5, 2.0, 2.5});
c.SetVals({1, 2, 3, 4});

// Reference from the raw operands.
auto out_ref = make_tensor<TestType>({N});
(out_ref = polyval(x, c)).run(exec);

// Wrap both operands in lifecycle probes.
PreRunLifecycle sx, sc;
auto out_test = make_tensor<TestType>({N});
(out_test = polyval(make_prerun_tester(x, sx), make_prerun_tester(c, sc))).run(exec);
exec.sync();

ExpectLifecycleClean(sx, "input");
ExpectLifecycleClean(sc, "coeffs");

for (int i = 0; i < N; i++) {
ASSERT_NEAR(static_cast<double>(out_test(i)), static_cast<double>(out_ref(i)),
1e-4) << "mismatch at index " << i;
}

MATX_EXIT_HANDLER();
}
49 changes: 47 additions & 2 deletions test/00_operators/remap_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "matx.h"
#include "test_types.h"
#include "utilities.h"
#include "prerun_tester.h"

using namespace matx;
using namespace matx::test;
Expand Down Expand Up @@ -350,8 +351,52 @@ TYPED_TEST(OperatorTestsNumericAllExecs, RemapOp)
}
}
}
}
}
}

MATX_EXIT_HANDLER();
}

// Verify remap forwards PreRun/PostRun to both the data operand and the index
// operand by wrapping each in a lifecycle probe.
TYPED_TEST(OperatorTestsNumericAllExecsWithoutJIT, RemapOperatorIndex)
{
MATX_ENTER_HANDLER();
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;
using inner_type = typename inner_op_type_t<TestType>::type;

ExecType exec{};

constexpr int N = 5;
auto tiv = make_tensor<TestType>({N, N});
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
tiv(i, j) = inner_type(i * N + j);
}
}

auto idx = make_tensor<int>({3});
idx.SetVals({1, 2, 3}); // select source rows 1, 2, 3

// Reference from the raw operands.
auto out_ref = make_tensor<TestType>({3, N});
(out_ref = remap<0>(tiv, idx)).run(exec);

// Wrap both the data operand and the index operand in lifecycle probes.
PreRunLifecycle sd, si;
auto out_test = make_tensor<TestType>({3, N});
(out_test = remap<0>(make_prerun_tester(tiv, sd), make_prerun_tester(idx, si))).run(exec);
exec.sync();

ExpectLifecycleClean(sd, "data");
ExpectLifecycleClean(si, "index");

for (int i = 0; i < 3; i++) {
for (int j = 0; j < N; j++) {
ASSERT_EQ(out_test(i, j), out_ref(i, j)) << "mismatch at (" << i << "," << j << ")";
}
}

MATX_EXIT_HANDLER();
}
}
Loading