@@ -29,7 +29,7 @@ def encode_input_log(pos, neg):
2929 neg = log1mexp (pos )
3030
3131 result = jnp .stack ([pos , neg ], axis = 1 ).flatten ()
32- constants = jnp .array ([float (' -inf' ), 0 ], dtype = jnp .float32 )
32+ constants = jnp .array ([float (" -inf" ), 0 ], dtype = jnp .float32 )
3333 return jnp .concat ([constants , result ])
3434
3535
@@ -38,17 +38,22 @@ def encode_input_real(pos, neg):
3838 neg = 1 - pos
3939
4040 result = jnp .stack ([pos , neg ], axis = 1 ).flatten ()
41- constants = jnp .array ([0. , 1 ,], dtype = jnp .float32 )
41+ constants = jnp .array (
42+ [
43+ 0.0 ,
44+ 1 ,
45+ ],
46+ dtype = jnp .float32 ,
47+ )
4248 return jnp .concat ([constants , result ])
4349
4450
45-
4651def create_knowledge_layer (pointers , csrs , semiring ):
4752 pointers = [np .array (ptrs ) for ptrs in pointers ]
4853 num_segments = [len (csr ) - 1 for csr in csrs ] # needed for the jit
4954 csrs = [unroll_csr (np .array (csr , dtype = np .int32 )) for csr in csrs ]
5055 sum_layer , prod_layer = get_semiring (semiring )
51- encode_input = {' log' : encode_input_log , ' real' : encode_input_real }[semiring ]
56+ encode_input = {" log" : encode_input_log , " real" : encode_input_real }[semiring ]
5257
5358 @jax .jit
5459 def wrapper (pos , neg = None ):
@@ -69,29 +74,75 @@ def unroll_csr(csr):
6974 return np .repeat (ixs , repeats = deltas )
7075
7176
72- def log_sum_layer (num_segments , ptrs , csr , x ):
73- x = x [ptrs ]
74- x_max = segment_max (stop_gradient (x ), csr , indices_are_sorted = True , num_segments = num_segments )
def exp_max(num_segments, csr, x):
    """Exponentiate ``x`` after subtracting each segment's maximum.

    Args:
        num_segments: Number of segments, forwarded to ``segment_max``.
        csr: Segment id of every entry of ``x`` (passed with
            ``indices_are_sorted=True``).
        x: Values to shift and exponentiate.

    Returns:
        A pair ``(shifted_exp, segment_peaks)``: ``exp(x - segment_peaks[csr])``
        and the per-segment maxima that were subtracted.
    """
    # The maxima are taken on a gradient-stopped copy of the input.
    segment_peaks = segment_max(
        stop_gradient(x),
        csr,
        num_segments=num_segments,
        indices_are_sorted=True,
    )
    # NaNs in the shifted values (e.g. from (-inf) - (-inf)) are mapped to 0.0;
    # infinities are passed through unchanged.
    shifted = jnp.nan_to_num(
        x - segment_peaks[csr],
        copy=False,
        nan=0.0,
        posinf=float("inf"),
        neginf=float("-inf"),
    )
    return jnp.exp(shifted), segment_peaks
87+
88+
def log_sum_layer(num_segments, ptrs, csr, x):
    """Segment-wise log-sum-exp of the gathered inputs.

    Args:
        num_segments: Static number of output segments.
        ptrs: Indices used to gather the layer inputs from ``x``.
        csr: Segment id of every gathered entry (sorted).
        x: Log-space input values.

    Returns:
        ``log(sum_i exp(x[ptrs][i]) + EPSILON)`` for every segment, computed
        by shifting with the per-segment maximum returned by ``exp_max``.
    """
    shifted_exp, x_max = exp_max(num_segments, csr, x[ptrs])
    total = segment_sum(
        shifted_exp, csr, indices_are_sorted=True, num_segments=num_segments
    )
    return jnp.log(total + EPSILON) + x_max


# Only ``x`` is differentiable.  ``num_segments`` must stay a static Python
# int and ``ptrs``/``csr`` are integer index arrays, so they are declared
# non-differentiable; JAX then passes them to the JVP rule as leading
# positional arguments.
log_sum_layer = jax.custom_jvp(log_sum_layer, nondiff_argnums=(0, 1, 2))


@log_sum_layer.defjvp
def log_sum_layer_jvp(num_segments, ptrs, csr, primals, tangents):
    """JVP rule for ``log_sum_layer``.

    Args:
        num_segments, ptrs, csr: The non-differentiable arguments of the
            primal call.
        primals: 1-tuple ``(x,)`` holding the differentiable input.
        tangents: 1-tuple ``(x_dot,)`` holding the tangent of ``x``.

    Returns:
        ``(out, out_dot)`` where ``out_dot`` has the same shape as ``out``
        and equals the weighted sum of the gathered input tangents, with
        weights ``exp(x[ptrs] - out[csr])``.
    """
    (x,) = primals
    (x_dot,) = tangents
    out = log_sum_layer(num_segments, ptrs, csr, x)

    gathered = x[ptrs]
    # Inputs equal to -inf contribute nothing to the sum, so their weight is
    # exactly zero; the mask also hides the NaN from (-inf) - (-inf) when a
    # whole segment is -inf.
    weights = jnp.where(
        jnp.isneginf(gathered), 0.0, jnp.exp(gathered - out[csr])
    )
    out_dot = segment_sum(
        weights * x_dot[ptrs],
        csr,
        indices_are_sorted=True,
        num_segments=num_segments,
    )
    return out, out_dot
114+
115+
def log_prod_layer(num_segments, ptrs, csr, x):
    """Product layer of the log semiring: a segment-wise sum of log-values.

    Args:
        num_segments: Static number of output segments.
        ptrs: Indices used to gather the layer inputs from ``x``.
        csr: Segment id of every gathered entry (sorted).
        x: Log-space input values.

    Returns:
        The per-segment sum of ``x[ptrs]``.
    """
    return segment_sum(
        x[ptrs], csr, num_segments=num_segments, indices_are_sorted=True
    )


# Only ``x`` is differentiable.  ``num_segments`` must stay a static Python
# int and ``ptrs``/``csr`` are integer index arrays, so they are declared
# non-differentiable; JAX then passes them to the JVP rule as leading
# positional arguments.
log_prod_layer = jax.custom_jvp(log_prod_layer, nondiff_argnums=(0, 1, 2))


@log_prod_layer.defjvp
def log_prod_layer_jvp(num_segments, ptrs, csr, primals, tangents):
    """JVP rule for ``log_prod_layer``.

    Args:
        num_segments, ptrs, csr: The non-differentiable arguments of the
            primal call.
        primals: 1-tuple ``(x,)`` holding the differentiable input.
        tangents: 1-tuple ``(x_dot,)`` holding the tangent of ``x``.

    Returns:
        ``(out, out_dot)``.  The layer is linear in ``x``, so the output
        tangent is the same segment-wise sum applied to ``x_dot``.
    """
    (x,) = primals
    (x_dot,) = tangents
    out = log_prod_layer(num_segments, ptrs, csr, x)
    out_dot = segment_sum(
        x_dot[ptrs], csr, num_segments=num_segments, indices_are_sorted=True
    )
    return out, out_dot
130+
131+
def sum_layer(num_segments, ptrs, csr, x):
    """Gather ``x[ptrs]`` and add the gathered values up per segment.

    Args:
        num_segments: Number of output segments.
        ptrs: Indices selecting the layer inputs from ``x``.
        csr: Segment id of every gathered entry (sorted).
        x: Input values.

    Returns:
        The per-segment sums, as produced by ``segment_sum``.
    """
    gathered = x[ptrs]
    return segment_sum(
        gathered,
        csr,
        indices_are_sorted=True,
        num_segments=num_segments,
    )
85134
86135
def prod_layer(num_segments, ptrs, csr, x):
    """Gather ``x[ptrs]`` and multiply the gathered values per segment.

    Args:
        num_segments: Number of output segments.
        ptrs: Indices selecting the layer inputs from ``x``.
        csr: Segment id of every gathered entry (sorted).
        x: Input values.

    Returns:
        The per-segment products, as produced by ``segment_prod``.
    """
    gathered = x[ptrs]
    return segment_prod(
        gathered,
        csr,
        indices_are_sorted=True,
        num_segments=num_segments,
    )
89140
90141
def get_semiring(name: str):
    """Look up the ``(sum, product)`` layer functions for a semiring.

    Args:
        name: Either ``"real"`` or ``"log"``.

    Returns:
        A pair of callables: the sum-layer and the product-layer
        implementation for the requested semiring.

    Raises:
        ValueError: If ``name`` is not a known semiring.
    """
    if name == "log":
        return log_sum_layer, log_prod_layer
    if name == "real":
        return sum_layer, prod_layer
    raise ValueError(f"Unknown semiring {name}")
0 commit comments