55from .utils import unroll_ixs
66
77
8- def _create_layers (sum_layer , prod_layer , ixs_in , ixs_out ):
8+ def _create_layers (sum_layer , prod_layer , ixs_in , ixs_out , eps ):
99 layers = []
1010 for i , (ix_in , ix_out ) in enumerate (zip (ixs_in , ixs_out )):
1111 ix_in = torch .as_tensor (ix_in , dtype = torch .long )
1212 ix_out = torch .as_tensor (ix_out , dtype = torch .long )
1313 ix_out = unroll_ixs (ix_out )
1414 layer = prod_layer if i % 2 == 0 else sum_layer
15- layers .append (layer (ix_in , ix_out ))
15+ layers .append (layer (ix_in , ix_out , eps ))
1616 return nn .Sequential (* layers )
1717
1818
1919class CircuitModule (nn .Module ):
20- def __init__ (self , ixs_in , ixs_out , semiring = 'real' ):
20+ def __init__ (self , ixs_in , ixs_out , semiring : str = 'real' , eps : float = 0 ):
2121 super (CircuitModule , self ).__init__ ()
2222 self .semiring = semiring
23+ self ._eps = 0
24+
2325 self .sum_layer , self .prod_layer , self .zero , self .one , self .negate = \
2426 get_semiring (semiring , self .is_probabilistic ())
25- self .layers = _create_layers (self .sum_layer , self .prod_layer , ixs_in , ixs_out )
27+ self .layers = _create_layers (self .sum_layer , self .prod_layer , ixs_in , ixs_out , eps )
2628
27- def forward (self , x_pos , x_neg = None , eps = 0 ):
28- x = self .encode_input (x_pos , x_neg , eps )
29+ def forward (self , x_pos , x_neg = None ):
30+ x = self .encode_input (x_pos , x_neg )
2931 return self .layers (x )
3032
31- def encode_input (self , pos , neg , eps ):
33+ def encode_input (self , pos , neg ):
3234 if neg is None :
33- neg = self .negate (pos , eps )
35+ neg = self .negate (pos , self . _eps )
3436 x = torch .stack ([pos , neg ], dim = 1 ).flatten ()
35- units = torch .tensor ([self .zero , self .one ], dtype = torch . float32 , device = pos .device )
37+ units = torch .tensor ([self .zero , self .one ], dtype = pos . dtype , device = pos .device )
3638 return torch .cat ([units , x ])
3739
3840 def sparsity (self , nb_vars : int ) -> float :
39- sparse_params = sum (len (l .ix_out ) for l in self .layers )
40- layer_widths = [nb_vars ] + [l .out_shape [0 ] for l in self .layers ]
41+ sparse_params = sum (len (layer .ix_out ) for layer in self .layers )
42+ layer_widths = [nb_vars ] + [layer .out_shape [0 ] for layer in self .layers ]
4143 dense_params = sum (layer_widths [i ] * layer_widths [i + 1 ] for i in range (len (layer_widths ) - 1 ))
4244 return sparse_params / dense_params
4345
44- def to_pc (self , x_pos , x_neg = None , eps = 0 ):
46+ def to_pc (self , x_pos , x_neg = None ):
4547 """ Converts the circuit into a probabilistic circuit."""
4648 assert self .semiring == "log" or self .semiring == "real"
4749 pc = ProbabilisticCircuitModule ([], [], self .semiring )
48- print ("Making PC" , pc .sum_layer , pc .sum_layer )
4950 layers = []
5051
51- x = self .encode_input (x_pos , x_neg , eps )
52+ x = self .encode_input (x_pos , x_neg )
5253 for i , layer in enumerate (self .layers ):
5354 if isinstance (layer , self .sum_layer ):
54- new_layer = pc .sum_layer (layer .ix_in , layer .ix_out )
55+ new_layer = pc .sum_layer (layer .ix_in , layer .ix_out , layer . _eps )
5556 weights = x .log () if self .semiring == "real" else x
5657 new_layer .weights .data = weights [new_layer .ix_in ]
5758 else :
@@ -76,7 +77,7 @@ def sample(self):
7677 return y [2 ::2 ]
7778
7879 def condition (self , x_pos , x_neg ):
79- x = self .encode_input (x_pos , x_neg , None )
80+ x = self .encode_input (x_pos , x_neg )
8081 for layer in self .layers :
8182 x = layer .condition (x ) \
8283 if isinstance (layer , ProbabilisticCircuitLayer ) \
0 commit comments