Skip to content

Commit 905789a

Browse files
committed
implemented forward AD for log semiing in jax (nothgin tested)
1 parent 6cc9e99 commit 905789a

File tree

1 file changed

+64
-13
lines changed

1 file changed

+64
-13
lines changed

src/klay/backends/jax_backend.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def encode_input_log(pos, neg):
2929
neg = log1mexp(pos)
3030

3131
result = jnp.stack([pos, neg], axis=1).flatten()
32-
constants = jnp.array([float('-inf'), 0], dtype=jnp.float32)
32+
constants = jnp.array([float("-inf"), 0], dtype=jnp.float32)
3333
return jnp.concat([constants, result])
3434

3535

@@ -38,17 +38,22 @@ def encode_input_real(pos, neg):
3838
neg = 1 - pos
3939

4040
result = jnp.stack([pos, neg], axis=1).flatten()
41-
constants = jnp.array([0., 1,], dtype=jnp.float32)
41+
constants = jnp.array(
42+
[
43+
0.0,
44+
1,
45+
],
46+
dtype=jnp.float32,
47+
)
4248
return jnp.concat([constants, result])
4349

4450

45-
4651
def create_knowledge_layer(pointers, csrs, semiring):
4752
pointers = [np.array(ptrs) for ptrs in pointers]
4853
num_segments = [len(csr) - 1 for csr in csrs] # needed for the jit
4954
csrs = [unroll_csr(np.array(csr, dtype=np.int32)) for csr in csrs]
5055
sum_layer, prod_layer = get_semiring(semiring)
51-
encode_input = {'log': encode_input_log, 'real': encode_input_real}[semiring]
56+
encode_input = {"log": encode_input_log, "real": encode_input_real}[semiring]
5257

5358
@jax.jit
5459
def wrapper(pos, neg=None):
@@ -69,29 +74,75 @@ def unroll_csr(csr):
6974
return np.repeat(ixs, repeats=deltas)
7075

7176

72-
def log_sum_layer(num_segments, ptrs, csr, x):
73-
x = x[ptrs]
74-
x_max = segment_max(stop_gradient(x), csr, indices_are_sorted=True, num_segments=num_segments)
77+
def exp_max(num_segments, csr, x):
78+
x_max = segment_max(
79+
stop_gradient(x), csr, indices_are_sorted=True, num_segments=num_segments
80+
)
7581
x = x - x_max[csr]
76-
x = jnp.nan_to_num(x, copy=False, nan=0.0, posinf=float('inf'), neginf=float('-inf'))
82+
x = jnp.nan_to_num(
83+
x, copy=False, nan=0.0, posinf=float("inf"), neginf=float("-inf")
84+
)
7785
x = jnp.exp(x)
86+
return x, x_max
87+
88+
89+
@jax.custom_jvp
90+
def log_sum_layer(num_segments, ptrs, csr, x):
91+
x = x[ptrs]
92+
x, x_max = exp_max(num_segments, csr, x)
7893
x = segment_sum(x, csr, indices_are_sorted=True, num_segments=num_segments)
7994
x = jnp.log(x + EPSILON) + x_max
8095
return x
8196

8297

98+
@log_sum_layer.defjvp
99+
def log_prod_layer_jvp(num_segments, ptrs, csr, p_in, d_in):
100+
p_out = log_sum_layer(num_segments, ptrs, csr, p_in)
101+
102+
sign_d_in, mag_d_in = d_in
103+
sign_d = 1 - 2 * sign_d_in
104+
mag_d = mag_d_in - p_in + p_out[csr]
105+
106+
mag_d, mag_d_max = exp_max(num_segments, csr, mag_d)
107+
mag_d = mag_d * sign_d
108+
mag_d = segment_sum(mag_d, csr, indices_are_sorted=True, num_segments=num_segments)
109+
110+
sign_d_out = (jnp.sign(-mag_d) + 1) // 2
111+
mag_d_out = jnp.log(jnp.abs(mag_d) + EPSILON) + mag_d_max
112+
113+
return p_out, (sign_d_out, mag_d_out)
114+
115+
116+
@jax.custom_jvp
117+
def log_prod_layer(num_segments, ptrs, csr, x):
118+
return sum_layer(num_segments, ptrs, csr, x)
119+
120+
121+
@log_prod_layer.defjvp
122+
def log_prod_layer_jvp(num_segments, ptrs, csr, p_in, d_in):
123+
p_out = log_prod_layer(num_segments, ptrs, csr, p_in)
124+
125+
sign_d_in, mag_d_in = d_in
126+
d_out_sign = sum_layer(num_segments, ptrs, csr, sign_d_in) % 2
127+
d_out_mag = sum_layer(num_segments, ptrs, csr, mag_d_in)
128+
129+
return p_out, (d_out_sign, d_out_mag)
130+
131+
83132
def sum_layer(num_segments, ptrs, csr, x):
84133
return segment_sum(x[ptrs], csr, num_segments=num_segments, indices_are_sorted=True)
85134

86135

87136
def prod_layer(num_segments, ptrs, csr, x):
88-
return segment_prod(x[ptrs], csr, num_segments=num_segments, indices_are_sorted=True)
137+
return segment_prod(
138+
x[ptrs], csr, num_segments=num_segments, indices_are_sorted=True
139+
)
89140

90141

91142
def get_semiring(name: str):
92-
if name == 'real':
143+
if name == "real":
93144
return sum_layer, prod_layer
94-
elif name == 'log':
95-
return log_sum_layer, sum_layer
145+
elif name == "log":
146+
return log_sum_layer, log_prod_layer
96147
else:
97-
raise ValueError(f"Unknown semiring {name}")
148+
raise ValueError(f"Unknown semiring {name}")

0 commit comments

Comments
 (0)