-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtorch_kde.py
More file actions
72 lines (62 loc) · 2.08 KB
/
torch_kde.py
File metadata and controls
72 lines (62 loc) · 2.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from pykeops.torch import LazyTensor
class GaussianKDE(torch.nn.Module):
    r"""
    Unnormalised Gaussian kernel density estimate, evaluated lazily with KeOps.

    At a location ``x`` the estimate is
    ``sum_k w_k * exp(-||x - p_k||^2 / sigma)``, where ``p_k`` are the stored
    data points and ``w_k`` their weights (1 when no weights are given).
    ``sigma`` divides the *squared* distance directly, so it corresponds to
    ``2 * s**2`` for a Gaussian with standard deviation ``s``; no
    normalisation constant is applied.
    """

    def __init__(self, data, weights=None, sigma=3):
        r"""
        Inputs:
            :data: Tensor (N_points, N_dims)
            :weights: Tensor (N_points), or None for unit weights
            :sigma: Gaussian scale (divides the squared distance)
        """
        super().__init__()
        # Non-persistent buffers follow the module through .to()/.cuda()/
        # .double(), so they stay on x's device/dtype, without adding
        # entries to the module's state_dict.
        self.register_buffer("dpoints", data, persistent=False)
        # Stored as (N_points, 1) so it can be lifted to a KeOps "i" variable.
        self.register_buffer(
            "weights",
            None if weights is None else weights.unsqueeze(1),
            persistent=False,
        )
        self.sigma = sigma

    def forward(self, x):
        r"""
        Apply the kde at the given locations.

        Inputs:
            :x: Tensor (B, N_locs, N_dims)
        Returns:
            The (weighted) kernel sum over all data points at each location,
            reduced by KeOps; expected shape (B, N_locs, 1).
        """
        # KeOps reads the last three axes as (i, j, feature) and any leading
        # axes as batch axes. Data points are the "i" variable: (1, N_points, 1, D).
        points_i = LazyTensor(self.dpoints.unsqueeze(1).unsqueeze(0).contiguous())
        # Locations are the "j" variable: (B, 1, N_locs, D). Inserting the
        # singleton axis at -3 (not 0) keeps B as a batch axis, so B > 1 works.
        locs_j = LazyTensor(x.unsqueeze(-3).contiguous())
        sq_dist = (locs_j - points_i).square().sum(-1)  # symbolic (B, N_points, N_locs)
        kernel = (-sq_dist / self.sigma).exp()
        if self.weights is not None:
            # (1, N_points, 1, 1): one scalar weight per data point ("i" variable).
            w_i = LazyTensor(self.weights.unsqueeze(0).unsqueeze(-1).contiguous())
            kernel = kernel * w_i
        # Reduce over the data points (KeOps axis 0 = "i"). Passing
        # ``.sum(0, 1)`` would set dim=1, which overrides axis and sums over
        # the locations instead.
        return kernel.sum(axis=0)
class ParabolicKDE(torch.nn.Module):
    r"""
    Unnormalised parabolic (Epanechnikov) kernel density estimate, via KeOps.

    At a location ``x`` the estimate is
    ``sum_k w_k * max(0, 1 - ||(x - p_k) / sigma||^2)``, where ``p_k`` are the
    stored data points and ``w_k`` their weights (1 when no weights are
    given). Each point contributes only within distance ``sigma``; no
    normalisation constant is applied.
    """

    def __init__(self, data, weights=None, sigma=3):
        r"""
        Inputs:
            :data: Tensor (N_points, N_dims)
            :weights: Tensor (N_points), or None for unit weights
            :sigma: scale (support radius of the kernel)
        """
        super().__init__()
        # Non-persistent buffers follow the module through .to()/.cuda()/
        # .double(), so they stay on x's device/dtype, without adding
        # entries to the module's state_dict.
        self.register_buffer("dpoints", data, persistent=False)
        # Stored as (N_points, 1) so it can be lifted to a KeOps "i" variable.
        self.register_buffer(
            "weights",
            None if weights is None else weights.unsqueeze(1),
            persistent=False,
        )
        self.sigma = sigma

    def forward(self, x):
        r"""
        Apply the kde at the given locations.

        Inputs:
            :x: Tensor (B, N_locs, N_dims)
        Returns:
            The (weighted) kernel sum over all data points at each location,
            reduced by KeOps; expected shape (B, N_locs, 1).
        """
        # KeOps reads the last three axes as (i, j, feature) and any leading
        # axes as batch axes. Data points are the "i" variable: (1, N_points, 1, D).
        points_i = LazyTensor(self.dpoints.unsqueeze(1).unsqueeze(0).contiguous())
        # Locations are the "j" variable: (B, 1, N_locs, D). Inserting the
        # singleton axis at -3 (not 0) keeps B as a batch axis, so B > 1 works.
        locs_j = LazyTensor(x.unsqueeze(-3).contiguous())
        # Squared scaled distance ||u||^2; symbolic (B, N_points, N_locs).
        sq_dist = ((locs_j - points_i) / self.sigma).square().sum(-1)
        # Parabolic profile (1 - ||u||^2)_+ : the raw ||u||^2 grows with
        # distance and would not be a density kernel.
        kernel = (1 - sq_dist).relu()
        if self.weights is not None:
            # (1, N_points, 1, 1): one scalar weight per data point ("i" variable).
            w_i = LazyTensor(self.weights.unsqueeze(0).unsqueeze(-1).contiguous())
            kernel = kernel * w_i
        # Reduce over the data points (KeOps axis 0 = "i"). Passing
        # ``.sum(0, 1)`` would set dim=1, which overrides axis and sums over
        # the locations instead.
        return kernel.sum(axis=0)