
tblis_tensor_add with conjugated output tensor #87

@Thoemi09

Dear TBLIS developers,

I think there might be a problem with tblis_tensor_add when the resulting tensor B is conjugated.

Issue

Here is a minimal example:

#include <tblis/tblis.h>

#include <complex>
#include <iostream>
#include <vector>

int main() {
  // calculate b <-- alpha * a + beta * conj(b) using tblis_tensor_add
  // with alpha = 1, beta = 1, a = [(0, 1), (0, 2)], b = [(0, 3), (0, 4)]
  auto A     = std::vector<std::complex<double>>{{0, 1}, {0, 2}};
  auto B     = std::vector<std::complex<double>>{{0, 3}, {0, 4}};
  auto B_res = B; // copy of B to hold the result of the TBLIS add
  auto alpha = std::complex<double>{1, 0};
  auto beta  = std::complex<double>{1, 0};

  // set up TBLIS tensors
  auto len = std::vector<tblis::len_type>{2};
  auto str = std::vector<tblis::stride_type>{1};
  tblis::tblis_tensor tA(alpha, false, A.data(), 1, len.data(), str.data());
  tblis::tblis_tensor tB(beta, true, B_res.data(), 1, len.data(), str.data()); // conjugate B

  // perform vector addition
  tblis::tblis_tensor_add(nullptr, nullptr, &tA, "i", &tB, "i");

  // expected result: b = [(0, -2), (0, -2)]
  std::cout << "Result: [" << B_res[0] << ", " << B_res[1] << "]\n";
  std::cout << "Expected: [" << alpha * A[0] + beta * std::conj(B[0]) << ", " << alpha * A[1] + beta * std::conj(B[1]) << "]\n";
}

On my machine (MacBook Pro with macOS Tahoe) this gives

Result: [(6.49672e-314,1), (8.66229e-314,2)]
Expected: [(0,-2), (0,-2)]

Looking at the implementation, I would assume that it should calculate B... = alpha * A... + beta * op(B...), where op(...) is either the identity or complex conjugation.

Possible fix

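Applying the following changes seems to fix the problem. If I read the BLIS kernel documentation correctly, the original conj_B branch has two issues: the conjugation flag of scalv applies to the scalar alpha (x := conjalpha(alpha) * x), not to the vector, so it cannot conjugate B in place; and &one passes the address of the pointer returned by bli_obj_buffer_for_const rather than the pointer itself. The scal2v kernel computes y := alpha * conjx(x), so calling it with x = y = B1 and alpha = one conjugates B in place: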

diff --git a/tblis/frame/1t/dense/add.cxx b/tblis/frame/1t/dense/add.cxx
index 40a8df6..cab255c 100644
--- a/tblis/frame/1t/dense/add.cxx
+++ b/tblis/frame/1t/dense/add.cxx
@@ -198,7 +198,7 @@ void add(type_t type, const communicator& comm, const cntx_t* cntx,
     if (unit_A_AB == unit_B_AB)
     {
         auto add_ukr = reinterpret_cast<axpbyv_ker_ft>(bli_cntx_get_ukr_dt((num_t)type, BLIS_AXPBYV_KER, cntx));
-        auto scal_ukr = reinterpret_cast<scalv_ker_ft>(bli_cntx_get_ukr_dt((num_t)type, BLIS_SCALV_KER, cntx));
+        auto scal_ukr = reinterpret_cast<scal2v_ker_ft>(bli_cntx_get_ukr_dt((num_t)type, BLIS_SCAL2V_KER, cntx));
         auto one = bli_obj_buffer_for_const((num_t)type, &BLIS_ONE);

         comm.distribute_over_threads(n0, mn1,
@@ -217,7 +217,7 @@ void add(type_t type, const communicator& comm, const cntx_t* cntx,
                 iter_AB.next(A1, B1);

                 if (conj_B)
-                    scal_ukr(BLIS_CONJUGATE, n0_max-n0_min, &one, B1, stride_B_m, cntx);
+                    scal_ukr(BLIS_CONJUGATE, n0_max-n0_min, one, B1, stride_B_m, B1, stride_B_m, cntx);

                 add_ukr(conj_A ? BLIS_CONJUGATE : BLIS_NO_CONJUGATE,
                         n0_max-n0_min,

Running the code from above now gives

Result: [(0,-2), (0,-2)]
Expected: [(0,-2), (0,-2)]

Let me know if you need more information or if a PR with the changes would make things easier for you.
