@@ -1,6 +1,7 @@
 import math
 
 import torch
+from torch import nn
 
 CUTOFF = -math.log(2)
 
@@ -33,11 +34,11 @@ def unroll_csr(csr):
     return ixs.repeat_interleave(repeats=deltas)
 
 
-class KnowledgeModule(torch.nn.Module):
-    def __init__(self, pointers, csrs, semiring='real'):
+class KnowledgeModule(nn.Module):
+    def __init__(self, pointers, csrs, semiring='real', probabilistic=False):
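+        # The new `probabilistic` flag swaps in sum layers with learnable,
+        # locally normalised edge weights (see ProbabilisticKnowledgeLayer).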
         super(KnowledgeModule, self).__init__()
         layers = []
-        sum_layer, prod_layer, self.zero, self.one, self.negate = get_semiring(semiring)
+        sum_layer, prod_layer, self.zero, self.one, self.negate = get_semiring(semiring, probabilistic)
         for i, (ptrs, csr) in enumerate(zip(pointers, csrs)):
             ptrs = torch.as_tensor(ptrs)
             csr = torch.as_tensor(csr, dtype=torch.long)
@@ -46,85 +47,118 @@ def __init__(self, pointers, csrs, semiring='real'):
                 layers.append(prod_layer(ptrs, csr))
             else:
                 layers.append(sum_layer(ptrs, csr))
-        self.layers = torch.nn.Sequential(*layers)
+        self.layers = nn.Sequential(*layers)
 
     def forward(self, weights, neg_weights=None, eps=0):
         if neg_weights is None:
             neg_weights = self.negate(weights, eps)
         x = encode_input(weights, neg_weights, self.zero, self.one)
         return self.layers(x)
 
-    def sparsity(self, nb_vars):
+    def sparsity(self, nb_vars: int) -> float:
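+        # Number of stored edge parameters relative to an equivalent dense
+        # stack of layers with the same widths.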
         sparse_params = sum(len(l.csr) for l in self.layers)
         layer_widths = [nb_vars] + [l.out_shape[0] for l in self.layers]
         dense_params = sum(layer_widths[i] * layer_widths[i + 1] for i in range(len(layer_widths) - 1))
         return sparse_params / dense_params
 
 
-class KnowledgeLayer(torch.nn.Module):
+class KnowledgeLayer(nn.Module):
     def __init__(self, ptrs, csr):
         super().__init__()
         self.register_buffer('ptrs', ptrs)
         self.register_buffer('csr', csr)
         self.out_shape = (self.csr[-1].item() + 1,)
 
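+    # Shared scatter helper: `csr` assigns each entry of `src` to an output
+    # node; include_self=False keeps the uninitialised buffer out of the result.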
+    def _scatter_reduce(self, src: torch.Tensor, reduce: str):
+        output = torch.empty(self.out_shape, dtype=src.dtype, device=src.device)
+        output = torch.scatter_reduce(output, 0, index=self.csr, src=src, reduce=reduce, include_self=False)
+        return output
+
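+    # Numerically stable per-node logsumexp: subtract the (detached) per-node
+    # maximum, exponentiate, scatter-add, then add the maximum back in log space.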
+    def _logsumexp_scatter_reduce(self, x: torch.Tensor, epsilon: float):
+        with torch.no_grad():
+            max_output = torch.empty(self.out_shape, dtype=x.dtype, device=x.device)
+            max_output = torch.scatter_reduce(max_output, 0, index=self.csr, src=x, reduce="amax", include_self=False)
+        x = x - max_output[self.csr]
+        x.nan_to_num_(nan=0., posinf=float('inf'), neginf=float('-inf'))
+        x = torch.exp(x)
+
+        output = torch.full(self.out_shape, epsilon, dtype=x.dtype, device=x.device)
+        output = torch.scatter_add(output, 0, index=self.csr, src=x)
+        output = torch.log(output) + max_output
+        return output
+
+
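+# Base class for layers that carry one learnable weight per edge, i.e. one
+# parameter per entry of `ptrs`.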
+class ProbabilisticKnowledgeLayer(KnowledgeLayer):
+    def __init__(self, ptrs, csr):
+        super().__init__(ptrs, csr)
+        # ptrs is an integer index tensor, so draw the parameter with randn
+        # rather than randn_like, which is undefined for integer dtypes.
+        self.weights = nn.Parameter(torch.randn(ptrs.shape))
+
 
 class SumLayer(KnowledgeLayer):
     def forward(self, x):
-        output = torch.zeros(self.out_shape, dtype=x.dtype, device=x.device)
-        output = torch.scatter_add(output, 0, index=self.csr, src=x[self.ptrs])
-        return output
+        return self._scatter_reduce(x[self.ptrs], "sum")
 
 
 class ProdLayer(KnowledgeLayer):
     def forward(self, x):
-        output = torch.empty(self.out_shape, dtype=x.dtype, device=x.device)
-        output = torch.scatter_reduce(output, 0, index=self.csr, src=x[self.ptrs], reduce="prod", include_self=False)
-        return output
+        return self._scatter_reduce(x[self.ptrs], "prod")
 
 
 class MinLayer(KnowledgeLayer):
     def forward(self, x):
-        output = torch.empty(self.out_shape, dtype=x.dtype, device=x.device)
-        output = torch.scatter_reduce(output, 0, index=self.csr, src=x[self.ptrs], reduce="amin", include_self=False)
-        return output
+        return self._scatter_reduce(x[self.ptrs], "amin")
 
 
 class MaxLayer(KnowledgeLayer):
     def forward(self, x):
-        output = torch.empty(self.out_shape, dtype=x.dtype, device=x.device)
-        output = torch.scatter_reduce(output, 0, index=self.csr, src=x[self.ptrs], reduce="amax", include_self=False)
-        return output
+        return self._scatter_reduce(x[self.ptrs], "amax")
 
 
 class LogSumLayer(KnowledgeLayer):
     def forward(self, x, epsilon=10e-16):
-        x = x[self.ptrs]
-        with torch.no_grad():
-            max_output = torch.empty(self.out_shape, dtype=x.dtype, device=x.device)
-            max_output = torch.scatter_reduce(max_output, 0, index=self.csr, src=x, reduce="amax", include_self=False)
-        x = x - max_output[self.csr]
-        x.nan_to_num_(nan=0., posinf=float('inf'), neginf=float('-inf'))
-        x = torch.exp(x)
+        return self._logsumexp_scatter_reduce(x[self.ptrs], epsilon)
 
-        output = torch.full(self.out_shape, epsilon, dtype=x.dtype, device=x.device)
-        output = torch.scatter_add(output, 0, index=self.csr, src=x)
-        output = torch.log(output) + max_output
-        return output
+
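+# Sum layer with learnable edge weights, softmax-normalised over the incoming
+# edges of each sum node so they form a local distribution.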
+class ProbabilisticSumLayer(ProbabilisticKnowledgeLayer):
+    def forward(self, x):
+        x = self.get_edge_weights() * x[self.ptrs]
+        return self._scatter_reduce(x, "sum")
+
+    def get_edge_weights(self):
+        exp_weights = torch.exp(self.weights)
+        # The normaliser is per sum node; gather it back to the edges via csr
+        # before dividing.
+        norm = self._scatter_reduce(exp_weights, "sum")
+        return exp_weights / norm[self.csr]
+
+
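+# Log-space variant: log-softmax-normalised edge weights are added to the
+# child log-values before the logsumexp reduction.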
+class ProbabilisticLogSumLayer(ProbabilisticKnowledgeLayer):
+    def forward(self, x, epsilon=10e-16):
+        x = self.get_edge_weights(epsilon) + x[self.ptrs]
+        return self._logsumexp_scatter_reduce(x, epsilon)
+
+    def get_edge_weights(self, epsilon):
+        # As above, gather the per-node normaliser back to the edges.
+        norm = self._logsumexp_scatter_reduce(self.weights, epsilon)
+        return self.weights - norm[self.csr]
 
 
-def get_semiring(name: str):
+def get_semiring(name: str, probabilistic: bool):
     """
     For a given semiring, returns the sum and product layer,
     the zero and one elements, and a negation function.
     """
-    if name == "real":
-        return SumLayer, ProdLayer, 0, 1, negate_real
-    elif name == "log":
-        return LogSumLayer, SumLayer, float('-inf'), 0, log1mexp
-    elif name == "mpe":
-        return MaxLayer, ProdLayer, 0, 1, negate_real
-    elif name == "godel":
-        return MaxLayer, MinLayer, 0, 1, negate_real
+    if probabilistic:
+        if name == "real":
+            return ProbabilisticSumLayer, ProdLayer, 0, 1, negate_real
+        if name == "log":
+            return ProbabilisticLogSumLayer, SumLayer, float('-inf'), 0, log1mexp
+        raise ValueError(f"Unknown probabilistic semiring {name}")
     else:
+        if name == "real":
+            return SumLayer, ProdLayer, 0, 1, negate_real
+        elif name == "log":
+            return LogSumLayer, SumLayer, float('-inf'), 0, log1mexp
+        elif name == "mpe":
+            return MaxLayer, ProdLayer, 0, 1, negate_real
+        elif name == "godel":
+            return MaxLayer, MinLayer, 0, 1, negate_real
         raise ValueError(f"Unknown semiring {name}")
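
For reference, a minimal sketch of how the refactored layers behave in
isolation. The ptrs/csr values below are made up for illustration (in practice
they come from the compiled circuit); ptrs gathers child values from x, and
csr assigns each gathered edge to an output node:

    import torch

    # Two output nodes fed by four edges: node 0 sums x[0] and x[1],
    # node 1 sums x[1] and x[2].
    ptrs = torch.tensor([0, 1, 1, 2])
    csr = torch.tensor([0, 0, 1, 1])

    layer = SumLayer(ptrs, csr)          # out_shape == (2,)
    x = torch.tensor([0.5, 0.25, 0.25])
    print(layer(x))                      # tensor([0.7500, 0.5000])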