Skip to content

Commit 5065b61

Browse files
committed
Add demo for neural.slang
1 parent 6d2c37e commit 5065b61

File tree

6 files changed

+590
-0
lines changed

6 files changed

+590
-0
lines changed
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# Neural Demo - Using neural.slang
2+
3+
This demo showcases how to use Slang's `neural.slang` standard module to build a neural network for image reconstruction. The network learns to map UV coordinates to RGB colors, reconstructing a reference image through gradient-based optimization.
4+
This is a re-creation of the texture example in the https://github.com/shader-slang/neural-shading-s25 course.
5+
6+
## neural.slang Types Used
7+
8+
| Type | Description |
9+
|------|-------------|
10+
| `InlineVector<T, N>` | Fixed-size vector type with compile-time `.Size` constant |
11+
| `StructuredBufferStorage<T>` | GPU buffer storage implementing `IStorage<T>` interface |
12+
| `FFLayer<T, InVec, OutVec, Storage, Activation, HasBias>` | Feed-forward neural network layer |
13+
| `IdentityActivation<T>` | Pass-through activation (no transformation) |
14+
| `NoParam()` | Empty parameter for activations that don't need configuration |
15+
16+
## Before/After Comparison
17+
18+
This section shows the comparison between the original code and the code with the neural.slang APIs
19+
20+
### Lines of Code
21+
22+
Approximate lines of code comparison apart from comments.
23+
24+
| File Type | Original | neural.slang |
25+
|-----------|----------|--------------|
26+
| Slang | 171 | 104 |
27+
| Python | 103 | 66 |
28+
29+
### Vector Types
30+
31+
| Before (Manual) | After (neural.slang) |
32+
|-----------------|---------------------|
33+
| `float[4]` / `float4` | `InlineVector<float, 4>` |
34+
| `float[32]` | `InlineVector<float, 32>` |
35+
| `float[3]` / `float3` | `InlineVector<float, 3>` |
36+
| Manual size tracking | `Vec4.Size` compile-time constant |
37+
38+
### Parameter Storage
39+
40+
| Before (Manual) | After (neural.slang) |
41+
|-----------------|---------------------|
42+
| Separate weight/bias buffers | `StructuredBufferStorage<T>` wrapper |
43+
| Manual offset calculation | `Storage.getOffset()` method |
44+
| Manual parameter count | `FFLayer.ParameterCount` constant |
45+
46+
### Layer Forward Pass
47+
48+
| Before (Manual) | After (neural.slang) |
49+
|-----------------|---------------------|
50+
| Manual matrix multiply | `FFLayer.eval()` using `linearTransform` |
51+
| Explicit loops | Optimized internal implementation |
52+
| Manual bias addition | Handled by `FFLayer` |
53+
54+
**Before:**
55+
```slang
56+
[Differentiable]
57+
float[Outputs] forward(float[Inputs] x)
58+
{
59+
float[Outputs] y;
60+
[MaxIters(Outputs)]
61+
for (int row = 0; row < Outputs; ++row)
62+
{
63+
var sum = get_bias(row);
64+
[ForceUnroll]
65+
for (int col = 0; col < Inputs; ++col)
66+
sum += get_weight(row, col) * x[col];
67+
y[row] = sum;
68+
}
69+
return y;
70+
}
71+
```
72+
73+
**After:**
74+
```slang
75+
[Differentiable]
76+
OutputVec mlp_forward(Storage storage, InputVec input)
77+
{
78+
uint addr = 0u;
79+
let h0 = Layer0(addr, addr + INPUT_DIM * HIDDEN_DIM, LeakyReLU<float>(LEAKY_RELU_SLOPE)).eval<Storage>(storage, input);
80+
addr = Layer0.nextAddress(addr);
81+
let h1 = Layer1(addr, addr + HIDDEN_DIM * HIDDEN_DIM, LeakyReLU<float>(LEAKY_RELU_SLOPE)).eval<Storage>(storage, h0);
82+
addr = Layer1.nextAddress(addr);
83+
return Layer2(addr, addr + HIDDEN_DIM * OUTPUT_DIM, ExpActivation<float>()).eval<Storage>(storage, h1);
84+
}
85+
```
86+
87+
### Network Definition
88+
89+
| Before (Manual) | After (neural.slang) |
90+
|-----------------|---------------------|
91+
| Custom struct with manual layout | Type aliases for layers |
92+
| Hardcoded dimensions | Dimensions from vector types |
93+
| Manual weight indexing | Automatic address calculation |
94+
95+
**Before:**
96+
```slang
97+
struct Network
98+
{
99+
RWStructuredBuffer<float> layer0_weights; // 4*32 floats
100+
RWStructuredBuffer<float> layer0_biases; // 32 floats
101+
RWStructuredBuffer<float> layer1_weights; // 32*32 floats
102+
RWStructuredBuffer<float> layer1_biases; // 32 floats
103+
RWStructuredBuffer<float> layer2_weights; // 32*3 floats
104+
RWStructuredBuffer<float> layer2_biases; // 3 floats
105+
106+
[Differentiable]
107+
float3 forward(float4 input) { /* manual implementation */ }
108+
}
109+
```
110+
111+
**After:**
112+
```slang
113+
import neural;
114+
115+
// Type definitions using neural.slang
116+
typealias Vec4 = InlineVector<float, 4>;
117+
typealias Vec32 = InlineVector<float, 32>;
118+
typealias Vec3 = InlineVector<float, 3>;
119+
typealias Storage = StructuredBufferStorage<float>;
120+
typealias Act = IdentityActivation<float>;
121+
122+
typealias Layer0Type = FFLayer<float, Vec4, Vec32, Storage, Act, true>;
123+
typealias Layer1Type = FFLayer<float, Vec32, Vec32, Storage, Act, true>;
124+
typealias Layer2Type = FFLayer<float, Vec32, Vec3, Storage, Act, true>;
125+
126+
struct MLPNetwork
127+
{
128+
// One buffer per layer: [weights, biases] contiguous
129+
RWStructuredBuffer<float> layer0_params;
130+
RWStructuredBuffer<float> layer1_params;
131+
RWStructuredBuffer<float> layer2_params;
132+
133+
Vec3 forward(Vec4 input)
134+
{
135+
let storage0 = Storage(layer0_params);
136+
let ff0 = Layer0Type(storage0, 0u, INPUT_SIZE * HIDDEN_SIZE);
137+
Vec32 h0 = ff0.eval(NoParam(), input);
138+
// ...
139+
}
140+
}
141+
```
142+
143+
## Running the Demo
144+
145+
```bash
146+
cd slangpy-samples/examples/neural-demo
147+
python neural-demo.py
148+
```
149+
150+
The demo displays three panels:
151+
1. **Reference image** - Target to reconstruct
152+
2. **Network output** - Current reconstruction using FFLayer-based network
153+
3. **Loss visualization** - Per-pixel error
154+
155+
Loss values are printed to console and should decrease over time as the network learns.
156+

examples/neural_slang_demo/app.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
from typing import Callable, Optional
4+
import slangpy as spy
5+
from pathlib import Path
6+
7+
8+
class App:
9+
def __init__(
10+
self,
11+
title: str = "Forward Rasterizer Example",
12+
width: int = 1024,
13+
height: int = 1024,
14+
device_type: spy.DeviceType = spy.DeviceType.automatic,
15+
):
16+
super().__init__()
17+
18+
# Create spy window
19+
self._window = spy.Window(width=width, height=height, title=title, resizable=False)
20+
21+
# Create spy device with local include path for shaders
22+
self._device = spy.create_device(
23+
device_type, enable_debug_layers=True, include_paths=[Path(__file__).parent]
24+
)
25+
26+
# Load module of helpers
27+
self._module = spy.Module.load_from_file(self._device, "app.slang")
28+
29+
# Setup swapchain
30+
self.surface = self._device.create_surface(self._window)
31+
self.surface.configure(width=self._window.width, height=self._window.height)
32+
33+
# Will contain output texture
34+
self._output_texture: "spy.Texture" = self.device.create_texture(
35+
format=spy.Format.rgba16_float,
36+
width=width,
37+
height=height,
38+
mip_count=1,
39+
usage=spy.TextureUsage.shader_resource | spy.TextureUsage.unordered_access,
40+
label="output_texture",
41+
)
42+
43+
# Store mouse pos
44+
self._mouse_pos = spy.float2()
45+
46+
# Internal events
47+
self._window.on_keyboard_event = self._on_window_keyboard_event
48+
self._window.on_mouse_event = self._on_window_mouse_event
49+
self._window.on_resize = self._on_window_resize
50+
51+
# Hookable events
52+
self.on_keyboard_event: Optional[Callable[[spy.KeyboardEvent], None]] = None
53+
self.on_mouse_event: Optional[Callable[[spy.MouseEvent], None]] = None
54+
55+
@property
56+
def device(self) -> spy.Device:
57+
return self._device
58+
59+
@property
60+
def window(self) -> spy.Window:
61+
return self._window
62+
63+
@property
64+
def mouse_pos(self) -> spy.float2:
65+
return self._mouse_pos
66+
67+
@property
68+
def output(self) -> spy.Texture:
69+
return self._output_texture
70+
71+
def process_events(self):
72+
if self._window.should_close():
73+
return False
74+
self._window.process_events()
75+
return True
76+
77+
def present(self):
78+
if not self.surface.config:
79+
return
80+
image = self.surface.acquire_next_image()
81+
if not image:
82+
return
83+
84+
if (
85+
self._output_texture == None
86+
or self._output_texture.width != image.width
87+
or self._output_texture.height != image.height
88+
):
89+
self._output_texture = self.device.create_texture(
90+
format=spy.Format.rgba16_float,
91+
width=image.width,
92+
height=image.height,
93+
mip_count=1,
94+
usage=spy.TextureUsage.shader_resource | spy.TextureUsage.unordered_access,
95+
label="output_texture",
96+
)
97+
98+
command_encoder = self._device.create_command_encoder()
99+
command_encoder.blit(image, self._output_texture)
100+
command_encoder.set_texture_state(image, spy.ResourceState.present)
101+
self._device.submit_command_buffer(command_encoder.finish())
102+
103+
del image
104+
self.surface.present()
105+
106+
def blit(
107+
self,
108+
source: spy.Tensor,
109+
size: Optional[spy.int2] = None,
110+
offset: Optional[spy.int2] = None,
111+
tonemap: bool = True,
112+
bilinear: bool = False,
113+
):
114+
if len(source.shape) != 2:
115+
raise ValueError("Source tensor must be 2D (height, width).")
116+
if size is None:
117+
size = spy.int2(source.shape[1], source.shape[0])
118+
if offset is None:
119+
offset = spy.int2(0, 0)
120+
self._module.blit(
121+
spy.grid((size.y, size.x)), size, offset, tonemap, bilinear, source, self.output
122+
)
123+
124+
def _on_window_keyboard_event(self, event: spy.KeyboardEvent):
125+
if event.type == spy.KeyboardEventType.key_press:
126+
if event.key == spy.KeyCode.escape:
127+
self._window.close()
128+
return
129+
elif event.key == spy.KeyCode.f1:
130+
if self._output_texture:
131+
spy.tev.show_async(self._output_texture)
132+
return
133+
elif event.key == spy.KeyCode.f2:
134+
if self._output_texture:
135+
bitmap = self._output_texture.to_bitmap()
136+
bitmap.convert(
137+
spy.Bitmap.PixelFormat.rgb,
138+
spy.Bitmap.ComponentType.uint8,
139+
srgb_gamma=True,
140+
).write_async("screenshot.png")
141+
return
142+
if self.on_keyboard_event:
143+
self.on_keyboard_event(event)
144+
145+
def _on_window_mouse_event(self, event: spy.MouseEvent):
146+
if event.type == spy.MouseEventType.move:
147+
self._mouse_pos = event.pos
148+
if self.on_mouse_event:
149+
self.on_mouse_event(event)
150+
151+
def _on_window_resize(self, width: int, height: int):
152+
self._device.wait()
153+
if width > 0 and height > 0:
154+
self.surface.configure(width=width, height=height)
155+
else:
156+
self.surface.unconfigure()
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
import slangpy;
4+
5+
// Helper function to turn pixel coordinates into normalized UVs.
6+
float2 pixel_to_uv(int2 pixel, int2 resolution)
7+
{
8+
return (float2(pixel) + 0.5f) / float2(resolution);
9+
}
10+
11+
float3 tonemap_aces_film(float3 input)
12+
{
13+
float3 x = input.xyz;
14+
float a = 2.51;
15+
float b = 0.03;
16+
float c = 2.43;
17+
float d = 0.59;
18+
float e = 0.14;
19+
float3 col = saturate((x * (a * x + b)) / (x * (c * x + d) + e));
20+
return col;
21+
}
22+
23+
void blit(
24+
int2 pixel,
25+
int2 size,
26+
int2 offset,
27+
bool tonemap,
28+
bool bilinear,
29+
Tensor<float3, 2> input,
30+
RWTexture2D<float4> output,
31+
)
32+
{
33+
uint2 output_dimensions;
34+
output.GetDimensions(output_dimensions.x, output_dimensions.y);
35+
36+
float2 uv = pixel_to_uv(pixel, size);
37+
38+
uint[2] shape = input.shape;
39+
40+
float3 col = bilinear ?
41+
input.sample(uv) :
42+
input.getv(uint2(float2(shape[1], shape[0]) * uv));
43+
44+
col = abs(col);
45+
46+
if(tonemap)
47+
{
48+
col = tonemap_aces_film(col);
49+
}
50+
51+
if(any(isnan(col.x)))
52+
{
53+
output[offset + pixel] = float4(1,0,1,1);
54+
}
55+
else
56+
{
57+
output[offset + pixel] = float4(col, 1.0f);
58+
59+
}
60+
61+
}

0 commit comments

Comments
 (0)