-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathenergy.py
More file actions
109 lines (90 loc) · 5.03 KB
/
Copy pathenergy.py
File metadata and controls
109 lines (90 loc) · 5.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import math
from torchtyping import TensorType
from typing import Dict, Any, Callable, Tuple
from siren.fim import NonparametricFIMDensity
def form_spatial_neighborhood_specification(model_fn : Callable[[TensorType["sim", "point", "coord"], TensorType["sim", "param"]], TensorType["sim", "point", "feat"]], ref_param : TensorType["param"], t : float, spatial_center : TensorType["space"], radius : float, n_samples : int, domain_transform = None) -> Callable[[TensorType["sim", "param"]], TensorType["sim"]]:
    """Build a specification score comparing model outputs around a space-time point.

    The returned ``spec(P, normalized_domain_params = None)`` draws ``n_samples``
    random spatial offsets (time offset 0) within ``radius`` of the query point.
    It evaluates ``model_fn`` on that neighborhood for ``ref_param`` (the target)
    and for every row of ``P``, and returns the negated mean Euclidean distance
    between the two feature sets. The result is one value per row of ``P``;
    values are <= 0 and closer to 0 means closer to the reference.

    Args:
        model_fn: maps (points, params) to per-point features. It is called with
            a ``(sim, point, coord)`` points tensor and a ``(sim, param)`` tensor.
        ref_param: parameter vector whose outputs form the target.
        t: time coordinate appended to the spatial center.
        spatial_center: spatial coordinates of the neighborhood center. Its
            device is used for every tensor created internally.
        radius: maximum distance of sampled offsets from the center.
        n_samples: number of offsets drawn per ``spec`` call. A fresh draw is
            made on every call.
        domain_transform: maps ``normalized_domain_params`` to ``(sim, coord)``
            neighborhood centers. It is only needed when ``spec`` is called
            with ``normalized_domain_params``.

    Returns:
        The ``spec`` closure.
    """
    # Formerly hard-coded to "cuda"; following spatial_center keeps identical
    # behaviour on the default GPU while also allowing CPU / other devices.
    device = spatial_center.device
    # Single space-time query point, shape (1, space + 1).
    domain_query = torch.cat((spatial_center, torch.tensor([t], device = device)), dim = 0).unsqueeze(0)

    def random_neighbors():
        """Draw ``n_samples`` offsets of shape (n_samples, space + 1) with a zero time component."""
        # NOTE(review): normalizing uniform-cube samples favours cube-corner
        # directions, and a uniform radius concentrates points near the center,
        # so this is not uniform over the ball. Kept as-is; confirm intent.
        rand_dirs = 2 * torch.rand(n_samples, spatial_center.shape[0], device = device) - 1
        rand_dirs /= rand_dirs.norm(dim = 1, keepdim = True)
        rand_r = radius * torch.rand(n_samples, 1, device = device)
        spatial_offsets = rand_r * rand_dirs
        return torch.cat((spatial_offsets, torch.zeros(n_samples, 1, device = device)), dim = 1)

    def form_neighborhood(points : TensorType["sim", "point"], neighborhood_stencil : TensorType["neighbors", "point"]):
        """Offset every point by every stencil entry: (sim, c) + (neighbors, c) -> (sim, neighbors, c)."""
        return points.unsqueeze(1) + neighborhood_stencil.unsqueeze(0)

    def spec(P : TensorType["sim", "param"], normalized_domain_params = None) -> TensorType["sim"]:
        """Return -mean_k ||model_fn(x_k, P_i) - model_fn(x_k, ref_param)|| for each row ``P_i``.

        Without ``normalized_domain_params`` every row uses the fixed query
        point. Otherwise the per-row centers are
        ``domain_transform(normalized_domain_params)``, while the target is
        still evaluated around the fixed query point.

        Raises:
            TypeError: ``normalized_domain_params`` was given but the spec was
                built without a ``domain_transform``.
        """
        neighborhood_pts = random_neighbors()
        # The reference features are a fixed target; no gradient flows into them.
        with torch.no_grad():
            target = model_fn(form_neighborhood(domain_query, neighborhood_pts), ref_param.unsqueeze(0))

        if normalized_domain_params is None:
            # Same query point replicated for every row of P: (sim, coord).
            centers = domain_query.expand(P.shape[0], -1)
        else:
            if domain_transform is None:
                raise TypeError("normalized_domain_params was given but domain_transform is None")
            centers = domain_transform(normalized_domain_params)

        feats = model_fn(form_neighborhood(centers, neighborhood_pts), P)
        # vector_norm equals sqrt(sum(diff**2)) but has a zero subgradient at
        # diff == 0, whereas .sqrt() back-propagates NaN there (e.g. when a row
        # of P coincides with ref_param).
        return -torch.linalg.vector_norm(feats - target, dim = 2).mean(dim = 1)

    return spec
#
def form_prior(density : NonparametricFIMDensity, param_bw : float, fim_bw : float) -> Callable[[TensorType["sim", "param"]], TensorType["sim"]]:
    """Wrap ``density.log_density_estimate`` as a one-argument prior.

    Args:
        density: estimator whose ``log_density_estimate`` is queried.
        param_bw: forwarded as the second argument of ``log_density_estimate``
            (presumably the parameter-space bandwidth; confirm in siren.fim).
        fim_bw: forwarded as the third argument of ``log_density_estimate``
            (presumably the FIM-space bandwidth; confirm in siren.fim).

    Returns:
        ``prior(P)``, which returns ``density.log_density_estimate(P, param_bw, fim_bw)``.
        The closure takes only ``P``; the previous annotation wrongly listed an
        extra ``float`` argument.
    """
    def prior(P : TensorType["sim", "param"]) -> TensorType["sim"]:
        # Bandwidths are fixed at construction time.
        return density.log_density_estimate(P, param_bw, fim_bw)

    return prior
#
def form_energy(prior : Callable[[TensorType["chain", "param"], float], TensorType["chain"]], spec : Callable[[TensorType["chain", "param"], float], TensorType["chain"]], spec_weight : float) -> Callable[[TensorType["chain", "param"], float, bool], TensorType["chain"]]:
    """Combine a prior and an optional specification score into an energy function.

    The returned ``energy(P, grad = False, apply_weight = True, t = 1.0)``
    behaves as follows, where ``w`` is ``spec_weight`` if ``apply_weight`` else 1.0:

    * ``grad=False``: returns ``t * prior(P) + t * w * spec(P)``, or
      ``t * prior(P)`` when ``spec`` is None.
    * ``grad=True``: subtracts the quadratic penalty
      ``sum(P**2, -1) / 2 / s**2`` from the same expression. Here ``s`` is
      20 without a spec and 0.8 with one. It returns ``(energy, d energy / d P)``,
      both free of autograd history. ``P`` is modified in place and ends with
      ``requires_grad`` False.

    Args:
        prior: called as ``prior(P)``; one value per row of ``P``.
        spec: called as ``spec(P)``, or None to use the prior alone.
        spec_weight: weight applied to ``spec`` when ``apply_weight`` is True.

    Returns:
        The ``energy`` closure.
    """
    def energy(P : TensorType["sim", "param"], grad : bool = False, apply_weight : bool = True, t : float = 1.0) -> TensorType["sim"]:
        actual_weight = spec_weight if apply_weight else 1.0

        if not grad:
            if spec is None:
                return t * prior(P)
            return t * prior(P) + t * actual_weight * spec(P)

        P.requires_grad_()
        # NOTE(review): this penalty is applied only on the gradient path, so
        # the energy returned with grad=True differs from grad=False for the
        # same P. Preserved from the original; confirm this is intended.
        penalty_scale = 20.0 if spec is None else 0.8
        penalty = ((P**2) / 2 / penalty_scale**2).sum(dim = -1)

        value = t * prior(P)
        if spec is not None:
            value = value + t * actual_weight * spec(P)
        value = value - penalty

        # autograd.grad with ones_like seeding is device/dtype agnostic. The
        # original used a CUDA-only ones tensor. It also does not accumulate
        # into a stale P.grad.
        (energy_grad,) = torch.autograd.grad(value, P, grad_outputs = torch.ones_like(value))
        P.grad = None
        P.detach_()
        return value.detach(), energy_grad

    return energy
#
def form_domain_energy(prior : Callable[[TensorType["chain", "param"], float], TensorType["chain"]], spec : Callable[[TensorType["chain", "param"], float], TensorType["chain"]], spec_weight : float) -> Callable[[TensorType["chain", "param"], float, bool], TensorType["chain"]]:
    """Build an energy over model parameters ``P`` and normalized domain parameters ``X``.

    The returned ``energy(P, X, grad = False, apply_weight = True)`` behaves as
    follows, where ``w`` is ``spec_weight`` if ``apply_weight`` else 1.0:

    * ``grad=False``: returns ``prior(P) + w * spec(P, normalized_domain_params = X)``.
    * ``grad=True``: subtracts linear penalties of slope 5 on every coordinate
      of ``X`` and ``P`` whose magnitude exceeds 1. It returns
      ``(energy, d energy / d P, d energy / d X)``, all free of autograd
      history. ``P`` and ``X`` are modified in place and end with
      ``requires_grad`` False.

    Args:
        prior: called as ``prior(P)``.
        spec: called as ``spec(P, normalized_domain_params = X)``. It must not
            be None.
        spec_weight: weight applied to ``spec`` when ``apply_weight`` is True.

    Returns:
        The ``energy`` closure.
    """
    def energy(P : TensorType["sim", "param"], X : TensorType["sim", "point"], grad : bool = False, apply_weight : bool = True) -> TensorType["sim"]:
        actual_weight = spec_weight if apply_weight else 1.0

        if not grad:
            return prior(P) + actual_weight * spec(P, normalized_domain_params = X)

        P.requires_grad_()
        X.requires_grad_()
        # NOTE(review): these penalties appear only on the gradient path, so the
        # energy returned with grad=True differs from grad=False. Preserved from
        # the original; confirm this is intended.
        hinge_X = 5 * torch.maximum(X.abs() - 1, torch.zeros_like(X))
        hinge_P = 5 * torch.maximum(P.abs() - 1, torch.zeros_like(P))
        value = prior(P) + actual_weight * spec(P, normalized_domain_params = X) - hinge_X.sum(dim = -1) - hinge_P.sum(dim = -1)

        # Device/dtype-agnostic seeding; the original used a CUDA-only ones
        # tensor. Both P and X always reach `value` through the penalties, so
        # autograd.grad never sees an unused input.
        P_grad, X_grad = torch.autograd.grad(value, (P, X), grad_outputs = torch.ones_like(value))
        P.grad = None
        P.detach_()
        X.grad = None
        X.detach_()
        return value.detach(), P_grad, X_grad

    return energy
#