
Commit 10af7d4

Initial commit.

1 parent 80badf4 commit 10af7d4

7 files changed: 354 insertions(+), 0 deletions(-)

Makefile

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
ALLIBS := $(patsubst %_module.cc, %.so, $(wildcard *_module.cc))

all: $(ALLIBS)

%.so: %_kernel.o %_module.o
	g++ -std=c++14 -shared -o $@ $^ $(TF_OPS_CFLAGS) -fPIC $(TF_OPS_LFLAGS)
	ln -sf compiled/$@ ../$@

%_module.o: %_module.cc
	g++ -std=c++14 -c -o $@ $< $(TF_OPS_CFLAGS) -fPIC

%_kernel.o: %_kernel.cc
	g++ -std=c++14 -c -o $@ $< $(TF_OPS_CFLAGS) -fPIC

clean:
	rm -f $(ALLIBS)

.PRECIOUS: %.o
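Note that TF_OPS_CFLAGS and TF_OPS_LFLAGS are not defined in this Makefile; they are expected to come from the environment. A plausible way to populate them (an assumption, not part of this commit) is from TensorFlow's own reported build flags:

# assumed convention: export these before running make; the variable names
# TF_OPS_CFLAGS / TF_OPS_LFLAGS come from the Makefile, the values from TF itself
import tensorflow as tf
print('TF_OPS_CFLAGS=' + ' '.join(tf.sysconfig.get_compile_flags()))
print('TF_OPS_LFLAGS=' + ' '.join(tf.sysconfig.get_link_flags()))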

README.md

Lines changed: 3 additions & 0 deletions
@@ -1,2 +1,5 @@
# tensorflow-custom-ops

Custom TensorFlow Ops for use in CMSSW

**Currently in development**

accknn_op.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@

import tensorflow as tf
from tensorflow.python.framework import ops
import globals as gl
from oc_helper_ops import SelectWithDefault

'''
Indices MUST be unique in each row.
The only exception is multiple self-references, which can be used as a sort of padding.
Alternatively, the index -1 is skipped (padding that is not TF compatible).
'''

_accknn_op = tf.load_op_library('accumulate_knn.so')
_accknn_grad_op = tf.load_op_library('accumulate_knn_grad.so')


def AccumulateLinKnn(weights, features, indices,
                     mean_and_max=True, force_tf=False):
    '''
    Accumulates neighbour features with linear weights (not exp(-w) as AccumulateKnn)
    '''
    if (not gl.acc_ops_use_tf_gradients) and (not force_tf):
        return _accknn_op.AccumulateKnn(distances=weights, features=features, indices=indices,
                                        n_moments=0, mean_and_max=mean_and_max)

    weights = tf.expand_dims(weights, axis=2)  # V x K x 1
    nfeat = SelectWithDefault(indices, features, 0.)  # V x K x F
    wfeat = weights * nfeat
    fmean = tf.reduce_mean(wfeat, axis=1)  # V x F
    fmax = tf.reduce_max(wfeat, axis=1)
    fout = fmean
    if mean_and_max:
        fout = tf.concat([fmean, fmax], axis=1)
    return fout, None


def AccumulateKnn(distances, features, indices,
                  mean_and_max=True, force_tf=False):
    '''
    .Output("out_features: float32")
    .Output("out_max_idxs: int32");

    Assumes that neighbour indices can be padded with -1, but not mixed,
    e.g. [1,4,-1,2] needs to be [1,4,2,-1].
    Other than the padding, the indices must be unique.
    '''
    # compatibility
    distances = tf.exp(-distances)

    if (not gl.acc_ops_use_tf_gradients) and (not force_tf):
        return _accknn_op.AccumulateKnn(distances=distances, features=features, indices=indices,
                                        n_moments=0, mean_and_max=mean_and_max)

    distances = tf.expand_dims(distances, axis=2)  # V x K x 1
    nfeat = SelectWithDefault(indices, features, 0.)  # V x K x F
    wfeat = distances * nfeat
    fmean = tf.reduce_mean(wfeat, axis=1)  # V x F
    fmax = tf.reduce_max(wfeat, axis=1)
    fout = fmean
    if mean_and_max:
        fout = tf.concat([fmean, fmax], axis=1)
    return fout, None


# this refers to the OP called AccumulateKnn, not the Python function above
@ops.RegisterGradient("AccumulateKnn")
def _AccumulateKnnGrad(op, grad, gradmaxidxs):
    """
    Gradient of AccumulateKnn: passes the incoming gradient to the compiled
    gradient op together with the op's inputs and the recorded max indices.
    """
    distances = op.inputs[0]
    features = op.inputs[1]
    max_feat_indices = op.outputs[1]
    neigh_indices = op.inputs[2]

    dist_grad, feat_grad = _accknn_grad_op.AccumulateKnnGrad(grad_from_out_features=grad,
                                                             distances=distances,
                                                             features=features,
                                                             neigh_indices=neigh_indices,
                                                             max_feat_indices=max_feat_indices)

    return [dist_grad, feat_grad, None]  # no gradient for indices
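A minimal usage sketch of the wrapper above (the shapes are an assumption: V vertices, K neighbours, F features; it needs the compiled .so files plus the surrounding globals and oc_helper_ops modules to be importable):

# hypothetical example: V=4 vertices, K=2 neighbours, F=3 features
import tensorflow as tf
from accknn_op import AccumulateKnn

distances = tf.random.uniform((4, 2))                      # V x K squared distances
features = tf.random.uniform((4, 3))                       # V x F per-vertex features
indices = tf.constant([[0, 1], [1, 0], [2, 3], [3, -1]])   # V x K, -1 padding at the end

out, _ = AccumulateKnn(distances, features, indices)       # V x 2F: [mean | max]
print(out.shape)                                           # (4, 6)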

accumulate_knn_kernel.cc

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/framework/op_kernel.h"
#include "accumulate_knn_kernel.h"
#include "helpers.h"
#include <string> //size_t, just for helper function
#include <cmath>

#include <iostream> //remove later DEBUG FIXME

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

static inline float distanceWeight(const float& distsq){
    return distsq;
}

// CPU specialization
template<typename dummy>
struct AccumulateKnnOpFunctor<CPUDevice, dummy> {
    void operator()(const CPUDevice &d,

            const float *d_distances,
            const float *d_feat,
            const int *d_idxs,

            float *d_out_feat,
            int *d_out_maxidxs,

            int n_vert,
            int n_neigh,
            int n_feat,

            int n_out_feat,

            int n_moments,
            bool mean_and_max) {

        for (size_t i_v = 0; i_v < n_vert; i_v++) {

            for (size_t i_f = 0; i_f < n_feat; i_f++) {
                float t_mean = 0;
                float t_max = 0;
                int max_i_n_gidx = 0;

                for (size_t i_n = 0; i_n < n_neigh; i_n++) {
                    int nidx = d_idxs[I2D(i_v, i_n, n_neigh)];

                    if (nidx < 0) continue;

                    float vnf = d_feat[I2D(nidx, i_f, n_feat)];
                    float distsq = d_distances[I2D(i_v, i_n, n_neigh)];
                    float wfeat = vnf * distanceWeight(distsq);
                    //DEBUGCOUT(wfeat);
                    t_mean += wfeat;
                    if (mean_and_max && (wfeat >= t_max || !i_n)) {
                        max_i_n_gidx = nidx;
                        t_max = wfeat;
                    }
                }
                t_mean /= (float)n_neigh;

                d_out_feat[I2D(i_v, i_f, n_out_feat)] = t_mean;
                if (mean_and_max) {
                    d_out_maxidxs[I2D(i_v, i_f, n_feat)] = max_i_n_gidx; //just used for gradient
                    d_out_feat[I2D(i_v, i_f + n_feat, n_out_feat)] = t_max;
                }
                //moments in n_coords x n_neigh loop here {}
            }
        }
    }
};

template<typename Device>
class AccumulateKnnOp : public OpKernel {
public:
    explicit AccumulateKnnOp(OpKernelConstruction *context) : OpKernel(context) {
        OP_REQUIRES_OK(context,
                context->GetAttr("n_moments", &n_moments));
        OP_REQUIRES_OK(context,
                context->GetAttr("mean_and_max", &mean_and_max));
    }

    void Compute(OpKernelContext *context) override {

        const Tensor &d_dist_tensor = context->input(0);
        const Tensor &d_feat_tensor = context->input(1);
        const Tensor &d_idxs_tensor = context->input(2);

        int n_vert = d_dist_tensor.dim_size(0);
        int n_neigh = d_idxs_tensor.dim_size(1);
        int n_coords = d_dist_tensor.dim_size(1);
        int n_feat = d_feat_tensor.dim_size(1);

        OP_REQUIRES(context, n_vert == d_idxs_tensor.dim_size(0) && n_vert == d_feat_tensor.dim_size(0),
                errors::InvalidArgument("AccumulateKnnOp expects first dimensions of all inputs to match."));

        OP_REQUIRES(context, n_neigh == d_dist_tensor.dim_size(1),
                errors::InvalidArgument("AccumulateKnnOp expects second dimension of distance and neighbour index tensor to match"));

        int n_out_feat = n_feat; //mean and max
        if (mean_and_max)
            n_out_feat *= 2;

        // after testing basic functionality!
        // n_out_feat += n_moments * n_feat * n_coords;

        TensorShape outputShape;
        outputShape.AddDim(n_vert);
        outputShape.AddDim(n_out_feat);

        Tensor *output_tensor = NULL;
        OP_REQUIRES_OK(context, context->allocate_output(0, outputShape, &output_tensor));

        TensorShape outputShape_max_idxs;
        outputShape_max_idxs.AddDim(n_vert);
        outputShape_max_idxs.AddDim(n_feat);

        Tensor *output_max_idxs_tensor = NULL;
        OP_REQUIRES_OK(context, context->allocate_output(1, outputShape_max_idxs, &output_max_idxs_tensor));

        AccumulateKnnOpFunctor<Device, int>()(
                context->eigen_device<Device>(),
                d_dist_tensor.flat<float>().data(),
                d_feat_tensor.flat<float>().data(),
                d_idxs_tensor.flat<int>().data(),
                output_tensor->flat<float>().data(),
                output_max_idxs_tensor->flat<int>().data(),
                n_vert,
                n_neigh,
                n_feat,
                n_out_feat,
                n_moments,
                mean_and_max
        );
    }

private:
    int n_moments;
    bool mean_and_max;
};

REGISTER_KERNEL_BUILDER(Name("AccumulateKnn").Device(DEVICE_CPU), AccumulateKnnOp<CPUDevice>);

#ifdef GOOGLE_CUDA
//extern template struct AccumulateKnnOpFunctor<GPUDevice, int>;
//REGISTER_KERNEL_BUILDER(Name("AccumulateKnn").Device(DEVICE_GPU), AccumulateKnnOp<GPUDevice>);
#endif  // GOOGLE_CUDA

}//functor
}//tensorflow
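For reference, a rough NumPy paraphrase of what the CPU functor above computes (an illustration, not part of the commit): per vertex and feature, the mean of the distance-weighted neighbour features over all K slots, plus the maximum and the index of the neighbour that produced it when mean_and_max is set.

# NumPy paraphrase of the CPU functor (illustration only)
import numpy as np

def accumulate_knn_ref(distances, features, indices, mean_and_max=True):
    n_vert, n_neigh = indices.shape
    n_feat = features.shape[1]
    out = np.zeros((n_vert, 2 * n_feat if mean_and_max else n_feat), np.float32)
    max_idxs = np.zeros((n_vert, n_feat), np.int32)
    for i_v in range(n_vert):
        for i_f in range(n_feat):
            t_mean, t_max, max_gidx = 0.0, 0.0, 0
            for i_n in range(n_neigh):
                nidx = indices[i_v, i_n]
                if nidx < 0:
                    continue  # -1 padding is skipped
                wfeat = features[nidx, i_f] * distances[i_v, i_n]
                t_mean += wfeat
                if mean_and_max and (wfeat >= t_max or i_n == 0):
                    max_gidx, t_max = nidx, wfeat
            # note: divides by n_neigh, not by the number of unpadded neighbours
            out[i_v, i_f] = t_mean / n_neigh
            if mean_and_max:
                max_idxs[i_v, i_f] = max_gidx  # just used for the gradient
                out[i_v, i_f + n_feat] = t_max
    return out, max_idxs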

accumulate_knn_kernel.h

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
// accumulate_knn_kernel.h
#ifndef ACCUMULATE_KNN_KERNEL_H
#define ACCUMULATE_KNN_KERNEL_H

namespace tensorflow {
namespace functor {

template<typename Device, typename dummy>
struct AccumulateKnnOpFunctor {
    void operator()(
            const Device &d,

            const float *d_distances,
            const float *d_feat,
            const int *d_idxs,

            float *d_out_feat,
            int *d_out_maxidxs,

            int n_vert,
            int n_neigh,
            int n_feat,

            int n_out_feat,

            int n_moments,
            bool mean_and_max);
};

} // namespace functor
} // namespace tensorflow

#endif //ACCUMULATE_KNN_KERNEL_H

accumulate_knn_module.cc

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("AccumulateKnn")
    .Attr("n_moments: int")
    .Attr("mean_and_max: bool")
    .Input("distances: float32") //change to distances!!
    .Input("features: float32")
    .Input("indices: int32")
    .Output("out_features: float32")
    .Output("out_max_idxs: int32");
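Once compiled, the registration above can be exercised directly from Python, mirroring how accknn_op.py calls it; a minimal sketch (the library name accumulate_knn.so follows the Makefile pattern, and the tensor contents are placeholder assumptions):

# hypothetical direct call against the op registered above
import tensorflow as tf

_mod = tf.load_op_library('accumulate_knn.so')
out_features, out_max_idxs = _mod.AccumulateKnn(
    distances=tf.zeros((4, 2)), features=tf.zeros((4, 3)),
    indices=tf.zeros((4, 2), dtype=tf.int32),
    n_moments=0, mean_and_max=True)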

helpers.h

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
/*
 * helpers.h
 *
 * Created on: 8 May 2020
 *     Author: jkiesele
 */

#ifndef HGCALML_MODULES_COMPILED_HELPERS_H_
#define HGCALML_MODULES_COMPILED_HELPERS_H_

#include <iostream>

// row-major index flattening; outer parentheses guard against
// expansion inside larger expressions
#define I2D(i,j,Nj) ((j) + (Nj)*(i))
#define I3D(i,j,k,Nj,Nk) ((k) + (Nk)*((j) + (Nj)*(i)))
#define I4D(i,j,k,l,Nj,Nk,Nl) ((l) + (Nl)*((k) + (Nk)*((j) + (Nj)*(i))))
#define I5D(i,j,k,l,m,Nj,Nk,Nl,Nm) ((m) + (Nm)*((l) + (Nl)*((k) + (Nk)*((j) + (Nj)*(i)))))

#define DEBUGCOUT(x) {std::cout << #x <<": " << x << std::endl;}

#endif /* HGCALML_MODULES_COMPILED_HELPERS_H_ */
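The IxD macros compute row-major offsets into flat arrays; a quick Python illustration of the convention (for explanation only):

# row-major flattening, equivalent to the I2D macro above
import numpy as np

def i2d(i, j, Nj):
    return j + Nj * i

a = np.arange(12).reshape(3, 4)          # Ni=3, Nj=4
assert a.flat[i2d(2, 1, 4)] == a[2, 1]   # offset 9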
