Skip to content
29 changes: 29 additions & 0 deletions include/gauge_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,13 @@ namespace quda {
*/
virtual void copy(const GaugeField &src) = 0;

/**
* @brief Generic gauge field shift. Pure virtual: every concrete
* gauge field class has to supply its own implementation.
* @param[in] src Source from which we are shifting (extended field in case of MPI)
* @param[in] dx Host array of shifts to apply to the field (one entry per dimension)
*/
virtual void shift(const GaugeField &src, const int *dx) = 0;

/**
@brief Compute the L1 norm of the field
@param[in] dim Which dimension we are taking the norm of (dim=-1 mean all dimensions)
Expand Down Expand Up @@ -543,6 +550,13 @@ namespace quda {
*/
void copy(const GaugeField &src);

/**
* @brief Generic gauge field shift (implementation of the virtual
* GaugeField::shift interface for this field type)
* @param[in] src Source from which we are shifting (extended field in case of MPI)
* @param[in] dx Host array of shifts to apply to the field (one entry per dimension)
*/
void shift(const GaugeField &src, const int *dx);

/**
@brief Download into this field from a CPU field
@param[in] cpu The CPU field source
Expand Down Expand Up @@ -680,6 +694,13 @@ namespace quda {
*/
void copy(const GaugeField &src);

/**
* @brief Generic gauge field shift (implementation of the virtual
* GaugeField::shift interface for this field type)
* @param[in] src Source from which we are shifting (extended field in case of MPI)
* @param[in] dx Host array of shifts to apply to the field (one entry per dimension)
*/
void shift(const GaugeField &src, const int *dx);

// Untyped access to the `gauge` member (mutable overload)
void* Gauge_p() { return gauge; }
// Untyped access to the `gauge` member (read-only overload for const fields)
const void* Gauge_p() const { return gauge; }

Expand Down Expand Up @@ -872,4 +893,12 @@ namespace quda {

// Forwards the supplied arguments to Reconstruct_, prepending the call site's
// function name, file and line number.
#define checkReconstruct(...) Reconstruct_(__func__, __FILE__, __LINE__, __VA_ARGS__)

/**
* @brief Generic gauge field shift
* @param[out] dst Gauge field to store output
* @param[in] src Source from which we are shifting (extended field in case of MPI)
* @param[in] dx Host array of shifts to apply to the field (one entry per dimension)
*/
void gaugeShift(GaugeField &dst, const GaugeField &src, const int *dx);

} // namespace quda
78 changes: 78 additions & 0 deletions include/kernels/gauge_shift.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#pragma once

#include <gauge_field_order.h>
#include <quda_matrix.h>
#include <index_helper.cuh>
#include <kernel.h>

namespace quda {

/**
   @brief Kernel argument bundle for the gauge field shift.

   Holds the accessors for the destination (out) and source (in) fields
   together with the geometry information the kernel needs to translate
   a destination site into the shifted site of the source field.

   @tparam Float_ Precision of the fields
   @tparam nColor_ Number of colors (only 3 is enabled)
   @tparam recon_u Reconstruction type of the fields
*/
template <typename Float_, int nColor_, QudaReconstructType recon_u>
struct GaugeShiftArg : kernel_param<> {
  using Float = Float_;
  static constexpr int nColor = nColor_;
  static_assert(nColor == 3, "Only nColor=3 enabled at this time");
  typedef typename gauge_mapper<Float,recon_u>::type Gauge;

  Gauge out;      // accessor for the destination field
  const Gauge in; // accessor for the source field
  int geometry;   // number of link directions stored per site

  int S[4];      // the shift to apply in each dimension
  int X[4];      // the regular volume parameters (dimensions of out)
  int E[4];      // the extended volume parameters (dimensions of in)
  int border[4]; // radius of border
  int P;         // change of parity (1 if the total shift is odd, else 0)

  /**
     @param[out] out Destination field
     @param[in] in Source field (may be larger than out, e.g. an extended field)
     @param[in] dx Host array holding the four shift components
  */
  GaugeShiftArg(GaugeField &out, const GaugeField &in, const int* dx) :
    // The kernel decodes the site index with the dimensions of `out` and
    // writes `out` at that index, so the x-extent of the thread grid must be
    // the checkerboard volume of the destination.  Using the volume of `in`
    // would let threads run past the end of `out` whenever `in` is larger.
    kernel_param(dim3(out.VolumeCB(), 2, in.Geometry())),
    out(out),
    in(in),
    geometry(in.Geometry())
  {
    P = 0;
    for (int i=0; i<4; i++) {
      S[i] = dx[i];
      X[i] = out.X()[i];
      E[i] = in.X()[i];
      border[i] = (E[i] - X[i])/2;
      P += dx[i];
    }
    // std::abs keeps the result in {0,1} for negative total shifts as well
    P = std::abs(P)%2;
  }
};

/**
   @brief Copy the link in direction dir from the shifted source site
   into the destination site identified by (idx, parity).
   @tparam Arg Argument struct type (provides in, out, X, E, S, border, P)
   @tparam dir Link direction handled by this instantiation
   @param[in] arg Kernel arguments
   @param[in] idx Checkerboard index of the destination site
   @param[in] parity Parity of the destination site
*/
template <typename Arg, int dir>
__device__ __host__ inline void GaugeShiftKernel(const Arg &arg, int idx, int parity)
{
  using Link = Matrix<complex<typename Arg::Float>, Arg::nColor>;

  // coordinates of the destination site with respect to the regular volume
  int coord[4] = {0, 0, 0, 0};
  getCoords(coord, idx, arg.X, parity);

  // offset by the border to obtain extended grid coordinates
  for (int mu = 0; mu < 4; ++mu) coord[mu] += arg.border[mu];

  // the source site sits on the opposite parity when arg.P is set
  const int src_parity = (arg.P == 1) ? (parity ^ 1) : parity;

  // index of the shifted site in the source field
  // (assumes linkIndexShift applies arg.S within dimensions arg.E -- TODO confirm)
  const int src_idx = linkIndexShift(coord, arg.S, arg.E);

  Link shifted = arg.in(dir, src_idx, src_parity);
  arg.out(dir, idx, parity) = shifted;
}

/**
   @brief Functor that dispatches the run-time direction index to the
   matching compile-time instantiation of GaugeShiftKernel.
*/
template <typename Arg> struct GaugeShift
{
  const Arg &arg;
  constexpr GaugeShift(const Arg &arg) : arg(arg) {}
  static constexpr const char *filename() { return KERNEL_FILE; }

  /**
     @param[in] x_cb Checkerboard site index
     @param[in] parity Site parity
     @param[in] dir Link direction; values not below arg.geometry, and
     values outside 0..3, are ignored
  */
  __device__ __host__ void operator()(int x_cb, int parity, int dir)
  {
    if (dir < arg.geometry) {
      if (dir == 0) {
        GaugeShiftKernel<Arg,0>(arg, x_cb, parity);
      } else if (dir == 1) {
        GaugeShiftKernel<Arg,1>(arg, x_cb, parity);
      } else if (dir == 2) {
        GaugeShiftKernel<Arg,2>(arg, x_cb, parity);
      } else if (dir == 3) {
        GaugeShiftKernel<Arg,3>(arg, x_cb, parity);
      }
    }
  }
};

}
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ set (QUDA_OBJS
copy_gauge_half.cu copy_gauge_quarter.cu
copy_gauge.cpp copy_gauge_mg.cu copy_clover.cu
copy_gauge_offset.cu copy_color_spinor_offset.cu copy_clover_offset.cu
gauge_shift.cu
staggered_oprod.cu clover_trace_quda.cu
hisq_paths_force_quda.cu
unitarize_force_quda.cu unitarize_links_quda.cu milc_interface.cpp
Expand Down
19 changes: 19 additions & 0 deletions lib/cpu_gauge_field.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,25 @@ namespace quda {
}
}

// Shift of a host gauge field.  A shift that is zero in every dimension
// is handled as a plain copy of src; for any other shift both branches
// at the end of this function report an error through errorQuda.
void cpuGaugeField::shift(const GaugeField &src, const int *dx) {
  // detect the trivial case: at least one dimension and all shifts zero
  bool zero_shift = (this->nDim > 0);
  for (int d = 0; d < this->nDim; d++) {
    if (dx[d] != 0) {
      zero_shift = false;
      break;
    }
  }

  // if zero shift, we simply copy
  if (zero_shift) {
    this->copy(src);
    return;
  }

  if (&src == this) errorQuda("Cannot copy in itself");

  checkField(src);

  // TODO: check src extension (needs to be enough for shifting)

  const bool src_is_cuda = (typeid(src) == typeid(cudaGaugeField));
  if (!src_is_cuda) {
    errorQuda("Not compatible type");
  } else {
    errorQuda("Not Implemented");
  }
}

void cpuGaugeField::setGauge(void **gauge_)
{
if(create != QUDA_REFERENCE_FIELD_CREATE) {
Expand Down
18 changes: 18 additions & 0 deletions lib/cuda_gauge_field.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,24 @@ namespace quda {
qudaDeviceSynchronize(); // include sync here for accurate host-device profiling
}

// Shift of a device gauge field.  A shift that is zero in every dimension
// is handled as a plain copy of src; otherwise the work is handed to
// gaugeShift, which requires src to be a cudaGaugeField.
void cudaGaugeField::shift(const GaugeField &src, const int *dx) {
  // detect the trivial case: at least one dimension and all shifts zero
  bool zero_shift = (this->nDim > 0);
  for (int d = 0; d < this->nDim; d++) {
    if (dx[d] != 0) {
      zero_shift = false;
      break;
    }
  }

  // nothing to shift, so this is just a copy
  if (zero_shift) {
    this->copy(src);
    return;
  }

  if (&src == this) errorQuda("Cannot copy in itself");

  checkField(src);

  // TODO: check src extension (needs to be enough for shifting)

  const bool src_is_cuda = (typeid(src) == typeid(cudaGaugeField));
  if (!src_is_cuda) {
    errorQuda("Not compatible type");
  } else {
    gaugeShift(*this, src, dx);
  }
}

void cudaGaugeField::loadCPUField(const cpuGaugeField &cpu) {
copy(cpu);
qudaDeviceSynchronize();
Expand Down
57 changes: 57 additions & 0 deletions lib/gauge_shift.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#include <tunable_nd.h>
#include <instantiate.h>
#include <gauge_field.h>
#include <kernels/gauge_shift.cuh>

namespace quda {

/**
   @brief Tunable launcher for the gauge field shift kernel.  The
   kernel is launched from the constructor on the default stream.
   @tparam Float Precision of the fields
   @tparam nColor Number of colors
   @tparam recon_u Reconstruction type of the fields
*/
template <typename Float, int nColor, QudaReconstructType recon_u> class ShiftGauge : public TunableKernel3D
{
  GaugeField &out;
  const GaugeField &in;
  const int* dx;

  // The kernel runs over the sites of the destination field (it decodes the
  // site index with the dimensions of `out` and writes `out` at that index),
  // so the launch is sized by the destination checkerboard volume.
  unsigned int minThreads() const { return out.VolumeCB(); }

public:
  /**
     @param[out] out Destination field
     @param[in] in Source field (extended field in case of MPI)
     @param[in] dx Host array of shifts, one entry per dimension of in
  */
  ShiftGauge(GaugeField &out, const GaugeField &in, const int * dx) :
    TunableKernel3D(in, 2, in.Geometry()),
    out(out),
    in(in),
    dx(dx)
  {
    // Encode the shift in the tuning key.  The components are separated by
    // '_' so that different shifts cannot collapse onto the same string
    // (without a separator, e.g. (1,12,0,0) and (11,2,0,0) both give "11200").
    strcat(aux, ",shift=");
    for (int i = 0; i < in.Ndim(); i++) {
      if (i > 0) strcat(aux, "_");
      strcat(aux, std::to_string(dx[i]).c_str());
    }
    strcat(aux, comm_dim_partitioned_string());
    apply(device::get_default_stream());
  }

  /**
     @brief Look up (or tune) the launch parameters and launch the kernel
     @param[in] stream Stream to launch on
  */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    launch<GaugeShift>(tp, stream, GaugeShiftArg<Float, nColor, recon_u>(out, in, dx));
  }

  // intentionally empty: no action is taken before or after tuning
  void preTune() { }
  void postTune() { }

  long long flops() const { return in.Volume() * 4; }
  long long bytes() const { return in.Bytes(); }
};

// Entry point for the gauge field shift: validates that source and
// destination are compatible, then instantiates ShiftGauge, whose
// constructor launches the kernel.
void gaugeShift(GaugeField& out, const GaugeField& in, const int *dx)
{
  // precision, location and reconstruction must agree between the two fields
  checkPrecision(in, out);
  checkLocation(in, out);
  checkReconstruct(in, out);

  // both fields must store the same number of link directions
  const auto geom_out = out.Geometry();
  const auto geom_in = in.Geometry();
  if (geom_out != geom_in) { errorQuda("Field geometries %d %d do not match", geom_out, geom_in); }

  // gauge field must be passed as first argument so we peel off its reconstruct type
  instantiate<ShiftGauge>(out, in, dx);
}

} // namespace quda