Skip to content

Commit f323e60

Browse files
committed
implemented sampling via autoregressive inverse transform
1 parent 1fbaa0c commit f323e60

File tree

9 files changed

+497
-56
lines changed

9 files changed

+497
-56
lines changed

analysis/plot_helper.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,6 @@ def plot_in_diverging_theme(
165165
# plt.figure()
166166
# f = plt.gcf()
167167
# the_ax = plt.gca()
168-
if plot_max_value is None:
169-
the_max = data.max()
170-
else:
171-
the_max = plot_max_value
172-
173-
if norm is not None:
174-
the_max = norm.vmax
175-
176168
if with_contour and plot_borders_below:
177169
if hatched:
178170
contour = the_ax.contourf(
@@ -197,22 +189,32 @@ def plot_in_diverging_theme(
197189
origin="upper",
198190
zorder=0,
199191
)
200-
the_data = data.copy()
201-
202-
# negative values are valid, positve values are invalid
203-
# I know this is stupid, but it is what it is
204-
the_data[valid.T] = (-1) * the_data[valid.T]
205-
# the_data[~valid] = (-1) * the_data[~valid]
206-
207-
the_ax.imshow(
208-
the_data,
209-
extent=img_extent,
210-
origin="upper",
211-
cmap=the_cmap,
212-
vmin=-the_max,
213-
vmax=the_max,
214-
interpolation="nearest",
215-
)
192+
193+
if data is not None:
194+
if plot_max_value is None:
195+
the_max = data.max()
196+
else:
197+
the_max = plot_max_value
198+
199+
if norm is not None:
200+
the_max = norm.vmax
201+
202+
the_data = data.copy()
203+
204+
# negative values are valid, positve values are invalid
205+
# I know this is stupid, but it is what it is
206+
the_data[valid.T] = (-1) * the_data[valid.T]
207+
# the_data[~valid] = (-1) * the_data[~valid]
208+
209+
the_ax.imshow(
210+
the_data,
211+
extent=img_extent,
212+
origin="upper",
213+
cmap=the_cmap,
214+
vmin=-the_max,
215+
vmax=the_max,
216+
interpolation="nearest",
217+
)
216218

217219
if with_valid:
218220
# plot valid tensor with true being pink and false transparent

analysis/plot_samples.ipynb

Lines changed: 419 additions & 0 deletions
Large diffs are not rendered by default.

pal/distribution/sampling/inverse_transform.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def bisection_search(
2727
upper = mid
2828
count += 1
2929

30-
print(f"bisection_search took {count} iterations")
30+
# print(f"bisection_search took {count} iterations")
3131
return (upper + lower) / 2
3232

3333

@@ -53,7 +53,7 @@ def conditioned_function(
5353
"""
5454

5555
def wrapped_function(x: torch.Tensor) -> torch.Tensor:
56-
x_full = torch.cat([x, constant.unsqueeze(0)], dim=-1)
56+
x_full = torch.vmap(lambda x: torch.cat([constant, x]))(x)
5757
return f(x_full)
5858

5959
return wrapped_function
@@ -122,7 +122,7 @@ def condition_linear_ineq(linear_ineq: lra.LinearInequality):
122122
return linear_ineq_res
123123

124124
the_constraints = constraints.map_constraints(
125-
f=condition_linear_ineq, drop_vars=list(sampled_dimensions.keys())
125+
f=condition_linear_ineq, #, drop_vars=list(sampled_dimensions.keys())
126126
)
127127

128128
degree = get_sub_total_degree(poly, i)
@@ -137,7 +137,7 @@ def condition_linear_ineq(linear_ineq: lra.LinearInequality):
137137
lower, upper = limits[var]
138138

139139
constrained_var_map_dict = {
140-
var: (i - idx) for var, idx in variable_map.items() if i >= idx
140+
var: (i - idx) for var, idx in variable_map.items() if idx >= i
141141
}
142142

143143
integration_kwargs = {
@@ -183,7 +183,7 @@ def sample_spline_distribution(
183183
Sample a single point from the SplineSQ2D distribution.
184184
"""
185185
assert not dist.is_batched()
186-
mixture_weights = dist.mixture_weights.unsqueeze(0)
186+
mixture_weights = dist.mixture_weights.squeeze(0)
187187
# sample the component index from the mixture weights
188188
component_index = torch.multinomial(mixture_weights, 1, replacement=True).item()
189189

@@ -209,11 +209,19 @@ def sample_spline_distribution(
209209
f.to(device)
210210
f = f.to(precision)
211211

212-
return inverse_transform_sampling_gasp(
212+
p = inverse_transform_sampling_gasp(
213213
boxed_problem,
214214
f,
215215
device=device,
216216
precision=precision,
217217
gasp_kwargs=gasp_kwargs,
218218
wmi_pa_mode=wmi_pa_mode,
219219
)
220+
221+
# assert that in box
222+
variable_map = dist.var_positions
223+
for var, idx in variable_map.items():
224+
lb, ub = box.constraints[var]
225+
assert lb <= p[idx] <= ub, f"Sampled point {p[idx]} is out of bounds for {var} ({lb}, {ub})"
226+
227+
return p

pal/distribution/spline_distribution.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ class SplineSQ2D(ConstrainedDistribution, torch.nn.Module):
420420
differences: torch.Tensor # (num_knots-1, 2) because of memory layout
421421
integrals_2dgrid: (
422422
torch.Tensor
423-
) # (b, num_mixtures, num_knots, num_knots), b=1 gets broadcasted
423+
) # (b, num_mixtures, num_knots-1, num_knots-1), b=1 gets broadcasted
424424
poly_params: (
425425
torch.Tensor
426426
) # (b, num_mixtures, 2, num_knots, 4), b=1 gets broadcasted
@@ -552,15 +552,18 @@ def enumerate_pieces(
552552
) -> list[tuple[lra.Box, torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]]:
553553
results = []
554554

555+
y_pos_dict = {i: name for name, i in self.var_positions.items()}
556+
y_pos_dict = y_pos_dict
557+
555558
for i in range(self.knots.shape[0] - 1):
556559
for j in range(self.knots.shape[0] - 1):
557560
lower_x0 = self.knots[i, 0]
558561
upper_x0 = self.knots[i + 1, 0]
559562
lower_x1 = self.knots[j, 1]
560563
upper_x1 = self.knots[j + 1, 1]
561564

562-
varname0 = self.y_pos_dict[0]
563-
varname1 = self.y_pos_dict[1]
565+
varname0 = y_pos_dict[0]
566+
varname1 = y_pos_dict[1]
564567

565568
box = lra.Box(
566569
id=(i, j),
@@ -570,7 +573,7 @@ def enumerate_pieces(
570573
},
571574
)
572575

573-
integrals = self.integrals_2dgrid[:, i, j]
576+
integrals = self.integrals_2dgrid[:, :, i, j]
574577
p1 = self.poly_params[:, :, 0, i]
575578
p2 = self.poly_params[:, :, 1, j]
576579

pal/distribution/torch_polynomial.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,8 @@ def __init__(
279279
VariableMapMixin.__init__(self, variable_map_dict)
280280
self.register_buffer("coeffs", coeffs)
281281
self.register_buffer("powers", powers)
282+
if params.shape[0] == 1:
283+
params = params.squeeze(0)
282284
self.register_buffer("params", params)
283285
assert coeffs.shape[0] == powers.shape[0]
284286

pal/logic/lra.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def __and__(self, other: "LRA"):
126126
if isinstance(other, And):
127127
return And(*(self.children + other.children))
128128
else:
129-
return And(*(self.children + [other]))
129+
return And(*(list(self.children) + [other]))
130130

131131
def __str__(self) -> str:
132132
return "(" + " & ".join([str(child) for child in self.children]) + ")"
@@ -259,7 +259,8 @@ def to_logic(var: str, lb: float, ub: float):
259259
return self._expression
260260

261261
def map_constraints(
262-
self, f: Callable[[LinearInequality], LinearInequality]
262+
self,
263+
f: Callable[[LinearInequality], LinearInequality]
263264
) -> "LRAProblem":
264265
"""
265266
Applies a function `f` to each `LinearConstraint` in the expression tree while keeping
@@ -268,6 +269,8 @@ def map_constraints(
268269
Args:
269270
f (Callable[[LinearConstraint], LinearConstraint]): The function to apply to each
270271
`LinearConstraint`.
272+
drop_vars (list[str]): A list of variables to to drop as a result of the mapping.
273+
This is used for removing variables that are no longer needed after the mapping.
271274
272275
Returns:
273276
LinearIneqLogicTree: A new expression tree with the modified `LinearConstraint` objects.
@@ -285,15 +288,15 @@ def recurse_expression(expr: LRA) -> LRA:
285288
return Or(*mapped_children)
286289

287290
if self.expression is None:
288-
return LRAProblem(None, self._variables)
291+
return LRAProblem(None, self._variables, self._name)
289292
else:
290293
expr = recurse_expression(self._expression)
291294
variables = gather_variables(expr)
292295
if isinstance(self._variables, dict):
293296
sub_vars = {var: self._variables[var] for var in variables}
294297
else:
295298
sub_vars = [var for var in self._variables if var in variables]
296-
return LRAProblem(expr, sub_vars)
299+
return LRAProblem(expr, sub_vars, self._name)
297300

298301
def get_global_limits(self) -> dict[str, tuple[float, float]]:
299302
"""
@@ -309,9 +312,9 @@ def get_global_limits(self) -> dict[str, tuple[float, float]]:
309312
"Global limits are not available when variables are provided as a list."
310313
)
311314

312-
def __and__(self, other: LRA | Box):
315+
def __and__(self, other: LRA | Box) -> "LRAProblem":
313316
if isinstance(other, LRA):
314-
return LRAProblem(self.expression & other, self._variables)
317+
return LRAProblem(self.expression & other, self._variables, self._name)
315318
elif isinstance(other, Box):
316319
# via global bounds
317320
if isinstance(self._variables, dict):

pal/logic/lra_pysmt.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pal.logic.lra as lra
2-
from pysmt.shortcuts import get_env, Symbol, Real, Plus, And, Or
2+
from pysmt.shortcuts import get_env, Symbol, Real, Plus, And, Or, Bool
33
import pysmt.typing as stypes
44
from pysmt.fnode import FNode
55

@@ -25,19 +25,23 @@ def recursive_translate(
2525
node: lra.LRA,
2626
) -> FNode:
2727
if isinstance(node, lra.LinearInequality):
28-
# canonical is ax + by + cz + ... < d
29-
left_sum = Plus(
30-
*[
31-
Real(lhs_coeff) * get_symb(lhs_var)
32-
for lhs_var, lhs_coeff in node.lhs.items()
33-
]
34-
)
35-
if node.symbol == "<=":
36-
return left_sum <= Real(node.rhs)
37-
elif node.symbol == ">=":
38-
return left_sum >= Real(node.rhs)
28+
if not node.is_empty():
29+
# canonical is ax + by + cz + ... < d
30+
left_sum = Plus(
31+
*[
32+
Real(lhs_coeff) * get_symb(lhs_var)
33+
for lhs_var, lhs_coeff in node.lhs.items()
34+
]
35+
)
36+
if node.symbol == "<=":
37+
return left_sum <= Real(node.rhs)
38+
elif node.symbol == ">=":
39+
return left_sum >= Real(node.rhs)
40+
else:
41+
raise NotImplementedError()
3942
else:
40-
raise NotImplementedError()
43+
# empty constraints are always true
44+
return Bool(True)
4145
elif isinstance(node, lra.And):
4246
recursive_translate_children = [
4347
recursive_translate(c) for c in node.children

pal/wmi/compute_integral.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from pysmt.fnode import FNode
55
import torch
66

7+
from pal.logic.lra import box_to_lra
78
from pal.logic.lra_pysmt import translate_to_pysmt
89
from pal.distribution.constrained_distribution import (
910
ConditionalConstraintedDistribution,
1011
ConstrainedDistributionBuilder,
11-
box_to_lra,
1212
)
1313
from pal.wmi.gasp.gasp.torch.wmipa.numerical_symb_integrator_pa import (
1414
FunctionMode,

pal/wmi/gasp/gasp/torch/wmipa/numerical_symb_integrator_pa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,8 +316,8 @@ def convert_to_problem_wrapper(problem):
316316
if len(results) == 0:
317317
return torch.zeros([1, len(problems)]), 0
318318
else:
319-
if any(r.shape[0] == 1 for r in results):
320-
print("hi")
319+
# if any(r.shape[0] == 1 for r in results):
320+
# print("hi")
321321
results = torch.concatenate(results, dim=-1)
322322
return results, 0
323323
case WeightedFormulaMode():

0 commit comments

Comments
 (0)