11import math
2- from functools import reduce
32
43import torch
54from torch import nn
@@ -97,9 +96,9 @@ def _safe_exp(self, x: torch.Tensor):
9796 x .nan_to_num_ (nan = 0. , posinf = float ('inf' ), neginf = float ('-inf' ))
9897 return torch .exp (x ), max_output
9998
100- def _logsumexp_scatter (self , x : torch .Tensor , epsilon : float ):
99+ def _logsumexp_scatter (self , x : torch .Tensor , eps : float ):
101100 x , max_output = self ._safe_exp (x )
102- output = torch .full (self .out_shape , epsilon , dtype = x .dtype , device = x .device )
101+ output = torch .full (self .out_shape , eps , dtype = x .dtype , device = x .device )
103102 output = torch .scatter_add (output , 0 , index = self .csr , src = x )
104103 output = torch .log (output ) + max_output
105104 return output
@@ -111,6 +110,24 @@ def __init__(self, ptrs, csr):
111110 super ().__init__ (ptrs , csr )
112111 self .weights = nn .Parameter (torch .randn_like (ptrs , dtype = torch .float32 ))
113112
113+ def get_edge_weights (self ):
114+ exp_weights , _ = self ._safe_exp (self .weights )
115+ norm = self ._scatter_forward (exp_weights , "sum" )
116+ return exp_weights / norm [self .csr ]
117+
118+ def get_log_edge_weights (self , eps ):
119+ norm = self ._logsumexp_scatter (self .weights , eps )
120+ return self .weights - norm [self .csr ]
121+
122+ def sample_pc (self , y , eps = 10e-16 ):
123+ weights = self .get_log_edge_weights (eps )
124+ noise = - (- torch .log (torch .rand_like (weights ) + eps ) + eps ).log ()
125+ gumbels = weights + noise
126+ samples = self ._scatter_forward (gumbels , "amax" )
127+ samples = samples [self .csr ] == gumbels
128+ samples &= y [self .csr ].to (torch .bool )
129+ return self ._scatter_backward (samples , "sum" ) > 0
130+
114131
115132class SumLayer (KnowledgeLayer ):
116133 def forward (self , x ):
@@ -124,6 +141,9 @@ class ProdLayer(KnowledgeLayer):
124141 def forward (self , x ):
125142 return self ._scatter_forward (x [self .ptrs ], "prod" )
126143
144+ def sample_pc (self , y ):
145+ return self ._scatter_backward (y [self .csr ], "sum" ) > 0
146+
127147
128148class MinLayer (KnowledgeLayer ):
129149 def forward (self , x ):
@@ -136,37 +156,20 @@ def forward(self, x):
136156
137157
138158class LogSumLayer (KnowledgeLayer ):
139- def forward (self , x , epsilon = 10e-16 ):
140- return self ._logsumexp_scatter (x [self .ptrs ], epsilon )
159+ def forward (self , x , eps = 10e-16 ):
160+ return self ._logsumexp_scatter (x [self .ptrs ], eps )
141161
142162
143163class ProbabilisticSumLayer (ProbabilisticKnowledgeLayer ):
144164 def forward (self , x ):
145165 x = self .get_edge_weights () * x [self .ptrs ]
146166 return self ._scatter_forward (x , "sum" )
147167
148- def get_edge_weights (self ):
149- exp_weights , _ = self ._safe_exp (self .weights )
150- norm = self ._scatter_forward (exp_weights , "sum" )
151- return exp_weights / norm [self .csr ]
152-
153168
154169class ProbabilisticLogSumLayer (ProbabilisticKnowledgeLayer ):
155- def forward (self , x , epsilon = 10e-16 ):
156- x = self .get_edge_weights (epsilon ) + x [self .ptrs ]
157- return self ._logsumexp_scatter (x , epsilon )
158-
159- def get_edge_weights (self , epsilon ):
160- norm = self ._logsumexp_scatter (self .weights , epsilon )
161- return self .weights - norm [self .csr ]
162-
163- def sample_pc (self , y , epsilon = 10e-16 ):
164- weights = self .get_edge_weights (epsilon )
165- gumbels = weights - (- torch .rand_like (weights ).log ()).log ()
166- samples = self ._scatter_forward (gumbels , "amax" )
167- samples = samples [self .csr ] == gumbels
168- samples &= y [self .csr ].to (torch .bool )
169- return self ._scatter_backward (samples , "sum" ) > 0
170+ def forward (self , x , eps = 10e-16 ):
171+ x = self .get_log_edge_weights (eps ) + x [self .ptrs ]
172+ return self ._logsumexp_scatter (x , eps )
170173
171174
172175def get_semiring (name : str , probabilistic : bool ):
0 commit comments