10
10
from chalk .backend .cairo import ToList
11
11
from chalk .trace import Trace
12
12
import jax
13
- jax .config .update ("jax_debug_nans" , True )
13
+ # jax.config.update("jax_debug_nans", True)
14
14
# blue = Color("#005FDB")
15
15
# d = rectangle(10, 10).rotate(30) #| rectangle(10, 10).align_t() | circle(50).align_t()
16
16
22
22
def make_scene (splits , mask , scene ):
23
23
split_int = np .floor (splits ).astype (int )
24
24
loc = np .arange (splits .shape [0 ]) % 2
25
- scene = scene .at [split_int * mask ].set (2 * (1 - loc ) - 1 )
25
+ scene = scene .at [split_int * mask ].add (2 * (1 - loc ) - 1 )
26
26
scene = scene .at [0 ].set (0 )
27
27
scene = np .cumsum (scene )
28
28
scene = scene .at [split_int * mask ].set (1.0 * (1 - loc ) + (2 * loc - 1 ) * ((splits ) - split_int ))
29
29
scene = scene .at [0 ].set (0 )
30
- out = jax .vmap (lambda s : scene [s + samples ] @ gaussian_kernel )(np .arange (scene .shape [0 ]))
30
+ out = jax .vmap (lambda s : scene [s + samples ] @ kernel ( 0 ) )(np .arange (scene .shape [0 ]))
31
31
return scene , out
32
32
33
33
def f_fwd (x , mask , scene ):
@@ -38,18 +38,21 @@ def f_bwd(res, g):
38
38
_ , g = g
39
39
splits , mask , scene = res
40
40
split_int = np .floor (splits ).astype (int )
41
- def grad_p (s ):
41
+ def grad_p (s , s_off ):
42
+ off = s_off - s
42
43
v = g [s + samples ]
43
- return (v @ gaussian_kernel ) * (scene [s - 1 ] - scene [s + 1 ])
44
- r = jax .vmap (grad_p ) (split_int ) * mask
44
+ return (v @ kernel ( off ) ) * (scene [s - 1 ] - scene [s + 1 ])
45
+ r = jax .vmap (grad_p , in_axes = ( 0 , 0 )) (split_int , splits ) * mask
45
46
return r , None , None
46
47
47
48
make_scene .defvjp (f_fwd , f_bwd )
48
49
49
50
kern = 11
50
51
samples = np .arange (kern ) - (kern // 2 )
51
- gaussian_kernel = np .exp (- ((samples / 2 ) ** 2 ))
52
- gaussian_kernel = gaussian_kernel / gaussian_kernel .sum ()
52
+ def kernel (offset ):
53
+ off_samples = samples - offset
54
+ gaussian_kernel = kern - np .abs (off_samples )
55
+ return np .maximum (0 , gaussian_kernel / (kern - np .abs (samples )).sum ())
53
56
54
57
def render_out (d , x = 200 , y = 200 ):
55
58
t = 4 * x
@@ -59,7 +62,7 @@ def color_row(i, counter):
59
62
idx = i // (t // y )
60
63
return make_scene (ps [i ], m [i ], counter )
61
64
i = np .arange (1 , ps .shape [0 ], 4 )
62
- scene , out1 = jax .vmap (color_row , in_axes = (0 , 0 ))(i , counter [i ])
65
+ scene , out1 = jax .vmap (color_row , in_axes = (0 , 0 ))(i , counter [i // 4 ])
63
66
# out1 = jax.vmap(lambda c:
64
67
# jax.vmap(lambda s: c[s + samples] @ gaussian_kernel)(np.arange(x)))(counter)
65
68
#out = out1[..., None] * np.array(color)
0 commit comments