@@ -28,26 +28,26 @@ def encode_input(pos, neg, zero, one):
     return torch.cat([constants, result])
 
 
-def unroll_csr(csr):
-    deltas = torch.diff(csr)
-    ixs = torch.arange(len(deltas), dtype=torch.long, device=csr.device)
+def unroll_ixs(ixs):
+    deltas = torch.diff(ixs)
+    ixs = torch.arange(len(deltas), dtype=torch.long, device=ixs.device)
     return ixs.repeat_interleave(repeats=deltas)
 
 
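Note: unroll_ixs expands CSR-style offsets into one output index per edge, which the scatter calls further down consume. A minimal sketch with a hypothetical offset tensor:

import torch

offsets = torch.tensor([0, 2, 5])            # two nodes with 2 and 3 incoming edges
deltas = torch.diff(offsets)                 # tensor([2, 3])
ixs = torch.arange(len(deltas), dtype=torch.long)
ixs.repeat_interleave(repeats=deltas)        # tensor([0, 0, 1, 1, 1])
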
 class KnowledgeModule(nn.Module):
-    def __init__(self, pointers, csrs, semiring='real', probabilistic=False):
+    def __init__(self, ixs_in, ixs_out, semiring='real', probabilistic=False):
         super(KnowledgeModule, self).__init__()
         layers = []
         self.probabilistic = probabilistic
         sum_layer, prod_layer, self.zero, self.one, self.negate = get_semiring(semiring, probabilistic)
-        for i, (ptrs, csr) in enumerate(zip(pointers, csrs)):
-            ptrs = torch.as_tensor(ptrs)
-            csr = torch.as_tensor(csr, dtype=torch.long)
-            csr = unroll_csr(csr)
+        for i, (ix_in, ix_out) in enumerate(zip(ixs_in, ixs_out)):
+            ix_in = torch.as_tensor(ix_in)
+            ix_out = torch.as_tensor(ix_out, dtype=torch.long)
+            ix_out = unroll_ixs(ix_out)
             if i % 2 == 0:
-                layers.append(prod_layer(ptrs, csr))
+                layers.append(prod_layer(ix_in, ix_out))
             else:
-                layers.append(sum_layer(ptrs, csr))
+                layers.append(sum_layer(ix_in, ix_out))
         self.layers = nn.Sequential(*layers)
 
     def forward(self, weights, neg_weights=None, eps=0):
@@ -57,7 +57,7 @@ def forward(self, weights, neg_weights=None, eps=0):
         return self.layers(x)
 
     def sparsity(self, nb_vars: int) -> float:
-        sparse_params = sum(len(l.csr) for l in self.layers)
+        sparse_params = sum(len(l.ix_out) for l in self.layers)
         layer_widths = [nb_vars] + [l.out_shape[0] for l in self.layers]
         dense_params = sum(layer_widths[i] * layer_widths[i + 1] for i in range(len(layer_widths) - 1))
         return sparse_params / dense_params
@@ -71,104 +71,104 @@ def sample_pc(self):
 
 
 class KnowledgeLayer(nn.Module):
-    def __init__(self, ptrs, csr):
+    def __init__(self, ix_in, ix_out):
         super().__init__()
-        self.register_buffer('ptrs', ptrs)
-        self.register_buffer('csr', csr)
-        self.out_shape = (self.csr[-1].item() + 1,)
-        self.in_shape = (self.ptrs.max() + 1,)
+        self.register_buffer('ix_in', ix_in)
+        self.register_buffer('ix_out', ix_out)
+        self.out_shape = (self.ix_out[-1].item() + 1,)
+        self.in_shape = (self.ix_in.max().item() + 1,)
 
     def _scatter_forward(self, x: torch.Tensor, reduce: str):
         output = torch.empty(self.out_shape, dtype=x.dtype, device=x.device)
-        output = torch.scatter_reduce(output, 0, index=self.csr, src=x, reduce=reduce, include_self=False)
+        output = torch.scatter_reduce(output, 0, index=self.ix_out, src=x, reduce=reduce, include_self=False)
         return output
 
     def _scatter_backward(self, x: torch.Tensor, reduce: str):
         output = torch.empty(self.in_shape, dtype=x.dtype, device=x.device)
-        output = torch.scatter_reduce(output, 0, index=self.ptrs, src=x, reduce=reduce, include_self=False)
+        output = torch.scatter_reduce(output, 0, index=self.ix_in, src=x, reduce=reduce, include_self=False)
         return output
 
 
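The renamed buffers make the direction explicit: _scatter_forward reduces one value per edge into the output node given by ix_out, and include_self=False lets the uninitialised torch.empty buffer be used safely. A quick check with hypothetical values:

import torch

src = torch.tensor([1., 2., 3., 4., 5.])     # one value per edge
ix_out = torch.tensor([0, 0, 1, 1, 1])       # as produced by unroll_ixs above
out = torch.empty(2)
torch.scatter_reduce(out, 0, index=ix_out, src=src, reduce="sum", include_self=False)
# tensor([ 3., 12.])
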
     def _safe_exp(self, x: torch.Tensor):
         with torch.no_grad():
             max_output = self._scatter_forward(x, "amax")
-        x = x - max_output[self.csr]
+        x = x - max_output[self.ix_out]
         x.nan_to_num_(nan=0., posinf=float('inf'), neginf=float('-inf'))
         return torch.exp(x), max_output
 
     def _logsumexp_scatter(self, x: torch.Tensor, eps: float):
         x, max_output = self._safe_exp(x)
         output = torch.full(self.out_shape, eps, dtype=x.dtype, device=x.device)
-        output = torch.scatter_add(output, 0, index=self.csr, src=x)
+        output = torch.scatter_add(output, 0, index=self.ix_out, src=x)
         output = torch.log(output) + max_output
         return output
 
 
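_logsumexp_scatter is the standard numerically stable log-sum-exp, applied per output group: subtract each group's max, exponentiate, scatter-add, take the log, then add the max back. A self-contained sanity check using the same grouping pattern:

import torch

x = torch.tensor([1000., 1000., -1000.])
ix = torch.tensor([0, 0, 1])
m = torch.full((2,), float('-inf')).scatter_reduce(0, ix, x, reduce="amax", include_self=False)
s = torch.zeros(2).scatter_add(0, ix, torch.exp(x - m[ix]))
torch.log(s) + m    # tensor([ 1000.6931, -1000.]), finite even though exp(1000.) overflows
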
 
 class ProbabilisticKnowledgeLayer(KnowledgeLayer):
-    def __init__(self, ptrs, csr):
-        super().__init__(ptrs, csr)
-        self.weights = nn.Parameter(torch.randn_like(ptrs, dtype=torch.float32))
+    def __init__(self, ix_in, ix_out):
+        super().__init__(ix_in, ix_out)
+        self.weights = nn.Parameter(torch.randn_like(ix_in, dtype=torch.float32))
 
     def get_edge_weights(self):
         exp_weights, _ = self._safe_exp(self.weights)
         norm = self._scatter_forward(exp_weights, "sum")
-        return exp_weights / norm[self.csr]
+        return exp_weights / norm[self.ix_out]
 
     def get_log_edge_weights(self, eps):
         norm = self._logsumexp_scatter(self.weights, eps)
-        return self.weights - norm[self.csr]
+        return self.weights - norm[self.ix_out]
 
     def sample_pc(self, y, eps=10e-16):
         weights = self.get_log_edge_weights(eps)
         noise = -(-torch.log(torch.rand_like(weights) + eps) + eps).log()
         gumbels = weights + noise
         samples = self._scatter_forward(gumbels, "amax")
-        samples = samples[self.csr] == gumbels
-        samples &= y[self.csr].to(torch.bool)
+        samples = samples[self.ix_out] == gumbels
+        samples &= y[self.ix_out].to(torch.bool)
         return self._scatter_backward(samples, "sum") > 0
 
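sample_pc relies on the Gumbel-max trick: adding independent Gumbel(0, 1) noise to the log-weights and taking the per-group argmax draws an edge with probability proportional to its normalised weight. A hypothetical frequency check of that identity:

import torch

torch.manual_seed(0)
log_w = torch.log(torch.tensor([0.2, 0.3, 0.5]))
g = -torch.log(-torch.log(torch.rand(100000, 3)))    # Gumbel(0, 1) noise
counts = torch.bincount((log_w + g).argmax(dim=1), minlength=3)
counts / counts.sum()    # approximately tensor([0.2, 0.3, 0.5])
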
 
 class SumLayer(KnowledgeLayer):
     def forward(self, x):
-        return self._scatter_forward(x[self.ptrs], "sum")
+        return self._scatter_forward(x[self.ix_in], "sum")
 
     def sample_pc(self, y):
-        return self._scatter_backward(y[self.csr], "sum") > 0
+        return self._scatter_backward(y[self.ix_out], "sum") > 0
 
 
 class ProdLayer(KnowledgeLayer):
     def forward(self, x):
-        return self._scatter_forward(x[self.ptrs], "prod")
+        return self._scatter_forward(x[self.ix_in], "prod")
 
     def sample_pc(self, y):
-        return self._scatter_backward(y[self.csr], "sum") > 0
+        return self._scatter_backward(y[self.ix_out], "sum") > 0
 
 
 class MinLayer(KnowledgeLayer):
     def forward(self, x):
-        return self._scatter_forward(x[self.ptrs], "amin")
+        return self._scatter_forward(x[self.ix_in], "amin")
 
 
 class MaxLayer(KnowledgeLayer):
     def forward(self, x):
-        return self._scatter_forward(x[self.ptrs], "amax")
+        return self._scatter_forward(x[self.ix_in], "amax")
 
 
 class LogSumLayer(KnowledgeLayer):
     def forward(self, x, eps=10e-16):
-        return self._logsumexp_scatter(x[self.ptrs], eps)
+        return self._logsumexp_scatter(x[self.ix_in], eps)
 
 
 class ProbabilisticSumLayer(ProbabilisticKnowledgeLayer):
     def forward(self, x):
-        x = self.get_edge_weights() * x[self.ptrs]
+        x = self.get_edge_weights() * x[self.ix_in]
         return self._scatter_forward(x, "sum")
 
 
 class ProbabilisticLogSumLayer(ProbabilisticKnowledgeLayer):
     def forward(self, x, eps=10e-16):
-        x = self.get_log_edge_weights(eps) + x[self.ptrs]
+        x = self.get_log_edge_weights(eps) + x[self.ix_in]
         return self._logsumexp_scatter(x, eps)
 
 
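A hypothetical end-to-end sketch of the renamed interface, assuming get_semiring('real', False) wires up ProdLayer and SumLayer: the first (product) layer has four edges into two nodes, the second (sum) layer two edges into one node.

import torch

ixs_in = [[0, 1, 2, 3], [0, 1]]    # source index per edge, layer by layer
ixs_out = [[0, 2, 4], [0, 2]]      # CSR offsets per layer, unrolled by unroll_ixs
module = KnowledgeModule(ixs_in, ixs_out, semiring='real')
module.sparsity(nb_vars=4)         # 6 edges vs. 4*2 + 2*1 = 10 dense weights -> 0.6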