1414from pathlib import Path
1515
1616# Create the app and load the slang module
17- app = App (width = 512 * 3 + 10 * 2 , height = 512 , title = "Neural Demo (FFLayer + DifferentialPtrPair)" , device_type = spy .DeviceType .vulkan )
17+ app = App (
18+ width = 512 * 3 + 10 * 2 ,
19+ height = 512 ,
20+ title = "Neural Demo (FFLayer + DifferentialPtrPair)" ,
21+ device_type = spy .DeviceType .vulkan ,
22+ )
1823module = spy .Module .load_from_file (app .device , "neural-demo.slang" )
1924
2025# Load reference image
@@ -28,7 +33,7 @@ def create_buffer(data: np.ndarray):
2833 element_count = data .size ,
2934 struct_size = 4 ,
3035 usage = spy .BufferUsage .shader_resource | spy .BufferUsage .unordered_access ,
31- data = data .astype ("float32" ).flatten ()
36+ data = data .astype ("float32" ).flatten (),
3237 )
3338
3439
@@ -38,7 +43,7 @@ def create_zero_buffer(size: int):
3843 element_count = size ,
3944 struct_size = 4 ,
4045 usage = spy .BufferUsage .shader_resource | spy .BufferUsage .unordered_access ,
41- data = np .zeros (size , dtype = "float32" )
46+ data = np .zeros (size , dtype = "float32" ),
4247 )
4348
4449
@@ -63,7 +68,14 @@ def __init__(self, width: int, height: int, num_latents: int):
6368
6469 def optimize (self , learning_rate : float , iteration : int ):
6570 # Use the actual Tensors for optimizer
66- module .optimizer_step (self ._texture_tensor , self ._texture_grads_tensor , self ._m_texture , self ._v_texture , learning_rate , iteration )
71+ module .optimizer_step (
72+ self ._texture_tensor ,
73+ self ._texture_grads_tensor ,
74+ self ._m_texture ,
75+ self ._v_texture ,
76+ learning_rate ,
77+ iteration ,
78+ )
6779
6880
6981class Network (spy .InstanceList ):
@@ -73,7 +85,7 @@ def __init__(self):
7385 super ().__init__ (module ["Network" ])
7486
7587 # Layer sizes: 4 -> 32 -> 32 -> 3
76- layer0_params = 32 * 4 + 32 # weights + biases
88+ layer0_params = 32 * 4 + 32 # weights + biases
7789 layer1_params = 32 * 32 + 32
7890 layer2_params = 3 * 32 + 3
7991 self .total_params = layer0_params + layer1_params + layer2_params
@@ -84,17 +96,17 @@ def __init__(self):
8496
8597 # Layer 0: 4 -> 32
8698 scale = np .sqrt (6.0 / (4 + 32 ))
87- params_np [offset : offset + 32 * 4 ] = np .random .uniform (- scale , scale , 32 * 4 )
88- offset += 32 * 4 + 32 # weights + biases
99+ params_np [offset : offset + 32 * 4 ] = np .random .uniform (- scale , scale , 32 * 4 )
100+ offset += 32 * 4 + 32 # weights + biases
89101
90102 # Layer 1: 32 -> 32
91103 scale = np .sqrt (6.0 / (32 + 32 ))
92- params_np [offset : offset + 32 * 32 ] = np .random .uniform (- scale , scale , 32 * 32 )
93- offset += 32 * 32 + 32
104+ params_np [offset : offset + 32 * 32 ] = np .random .uniform (- scale , scale , 32 * 32 )
105+ offset += 32 * 32 + 32
94106
95107 # Layer 2: 32 -> 3
96108 scale = np .sqrt (6.0 / (32 + 3 ))
97- params_np [offset : offset + 3 * 32 ] = np .random .uniform (- scale , scale , 3 * 32 )
109+ params_np [offset : offset + 3 * 32 ] = np .random .uniform (- scale , scale , 3 * 32 )
98110
99111 # Create tensors for optimizer (underscore prefix avoids struct binding issues)
100112 self ._params_tensor = spy .Tensor .from_numpy (app .device , params_np )
@@ -113,7 +125,14 @@ def __init__(self):
113125
114126 def optimize (self , learning_rate : float , iteration : int ):
115127 # Optimize MLP params using Tensors (for vectorization)
116- module .optimizer_step (self ._params_tensor , self ._params_grad_tensor , self ._m_params , self ._v_params , learning_rate , iteration )
128+ module .optimizer_step (
129+ self ._params_tensor ,
130+ self ._params_grad_tensor ,
131+ self ._m_params ,
132+ self ._v_params ,
133+ learning_rate ,
134+ iteration ,
135+ )
117136 # Optimize latent texture
118137 self .latent_texture .optimize (learning_rate , iteration )
119138
@@ -139,12 +158,16 @@ def optimize(self, learning_rate: float, iteration: int):
139158 # Render using unified network
140159 lr_output = spy .Tensor .empty_like (image )
141160 module .render (pixel = spy .call_id (), resolution = res , network = network , _result = lr_output )
142- app .blit (lr_output , size = spy .int2 (512 ), offset = spy .int2 (offset , 0 ), tonemap = False , bilinear = True )
161+ app .blit (
162+ lr_output , size = spy .int2 (512 ), offset = spy .int2 (offset , 0 ), tonemap = False , bilinear = True
163+ )
143164 offset += 522
144165
145166 # Show loss
146167 loss_output = spy .Tensor .empty_like (image )
147- module .loss (pixel = spy .call_id (), resolution = res , network = network , reference = image , _result = loss_output )
168+ module .loss (
169+ pixel = spy .call_id (), resolution = res , network = network , reference = image , _result = loss_output
170+ )
148171 app .blit (loss_output , size = spy .int2 (512 ), offset = spy .int2 (offset , 0 ), tonemap = False )
149172
150173 learning_rate = 0.001
0 commit comments