Skip to content

Commit d7ff92c

Browse files
authored
Merge pull request #12 from pemattern/11-add-option-for-custominput-data-to-tuishaderbackend
feat: user data can be added
2 parents 0f2ce92 + 1dc9b4b commit d7ff92c

File tree

4 files changed

+174
-53
lines changed

4 files changed

+174
-53
lines changed

src/backend/cpu.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
use super::{NoUserData, TuiShaderBackend};
2+
3+
pub struct CpuBackend<T> {
4+
callback: Box<dyn CpuShaderCallback<T>>,
5+
}
6+
7+
impl CpuBackend<NoUserData> {
8+
pub fn new<F>(callback: F) -> Self
9+
where
10+
F: Fn(u16, u16) -> [u8; 4] + 'static,
11+
{
12+
Self {
13+
callback: Box::new(CpuShaderCallbackWithoutUserData(callback)),
14+
}
15+
}
16+
}
17+
18+
impl<T> CpuBackend<T> {
19+
pub fn new_with_user_data<F>(callback: F) -> Self
20+
where
21+
F: Fn(u16, u16, &T) -> [u8; 4] + 'static,
22+
{
23+
Self {
24+
callback: Box::new(callback),
25+
}
26+
}
27+
}
28+
29+
impl<T> TuiShaderBackend<T> for CpuBackend<T> {
30+
fn execute(&mut self, width: u16, height: u16, user_data: &T) -> Vec<[u8; 4]> {
31+
let mut pixels = Vec::new();
32+
for y in 0..height {
33+
for x in 0..width {
34+
let value = self.callback.call(x, y, user_data);
35+
pixels.push(value);
36+
}
37+
}
38+
pixels
39+
}
40+
}
41+
42+
pub trait CpuShaderCallback<T> {
43+
fn call(&self, x: u16, y: u16, user_data: &T) -> [u8; 4];
44+
}
45+
46+
impl<T, F> CpuShaderCallback<T> for F
47+
where
48+
F: Fn(u16, u16, &T) -> [u8; 4],
49+
{
50+
fn call(&self, x: u16, y: u16, user_data: &T) -> [u8; 4] {
51+
self(x, y, user_data)
52+
}
53+
}
54+
55+
// NewType required to avoid conflicting implementations
56+
struct CpuShaderCallbackWithoutUserData<F>(F);
57+
impl<F> CpuShaderCallback<NoUserData> for CpuShaderCallbackWithoutUserData<F>
58+
where
59+
F: Fn(u16, u16) -> [u8; 4],
60+
{
61+
fn call(&self, x: u16, y: u16, _user_data: &NoUserData) -> [u8; 4] {
62+
self.0(x, y)
63+
}
64+
}

src/backend/mod.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
pub mod cpu;
12
pub mod wgpu;
23

3-
pub trait TuiShaderBackend: Eq {
4-
fn execute(&mut self, width: u16, height: u16) -> Vec<[u8; 4]>;
4+
#[repr(C)]
5+
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
6+
pub struct NoUserData;
7+
8+
pub trait TuiShaderBackend<T> {
9+
fn execute(&mut self, width: u16, height: u16, user_data: &T) -> Vec<[u8; 4]>;
510
}

src/backend/wgpu.rs

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
use std::{fs, time::Instant};
1+
use std::{fs, marker::PhantomData, time::Instant};
22

33
use pollster::FutureExt as _;
44
use wgpu::util::DeviceExt;
55

6-
use super::TuiShaderBackend;
6+
use super::{NoUserData, TuiShaderBackend};
77

88
#[repr(C)]
9-
#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
9+
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
1010
struct ShaderInput {
1111
// struct field order matters
1212
time: f32,
@@ -15,7 +15,10 @@ struct ShaderInput {
1515
}
1616

1717
#[derive(Debug, Clone, Eq, PartialEq)]
18-
pub struct WgpuBackend {
18+
pub struct WgpuBackend<T>
19+
where
20+
T: Copy + Clone + bytemuck::Pod + bytemuck::Zeroable,
21+
{
1922
device: wgpu::Device,
2023
queue: wgpu::Queue,
2124
pipeline: wgpu::RenderPipeline,
@@ -26,9 +29,13 @@ pub struct WgpuBackend {
2629
bind_group: wgpu::BindGroup,
2730
width: u16,
2831
height: u16,
32+
_user_data: PhantomData<T>,
2933
}
3034

31-
impl WgpuBackend {
35+
impl<T> WgpuBackend<T>
36+
where
37+
T: Copy + Clone + bytemuck::Pod + bytemuck::Zeroable,
38+
{
3239
pub fn new(path_to_fragment_shader: &str, entry_point: &str) -> Self {
3340
Self::new_inner(path_to_fragment_shader, entry_point).block_on()
3441
}
@@ -70,8 +77,8 @@ impl WgpuBackend {
7077
let width = 64u16;
7178
let height = 64u16;
7279

73-
let texture = WgpuBackend::create_texture(&device, width.into(), height.into());
74-
let buffer = WgpuBackend::create_buffer(&device, width.into(), height.into());
80+
let texture = WgpuBackend::<T>::create_texture(&device, width.into(), height.into());
81+
let output_buffer = WgpuBackend::<T>::create_buffer(&device, width.into(), height.into());
7582

7683
let shader_input = ShaderInput {
7784
time: creation_time.elapsed().as_secs_f32(),
@@ -150,11 +157,12 @@ impl WgpuBackend {
150157
pipeline,
151158
creation_time,
152159
texture,
153-
output_buffer: buffer,
154-
width,
155-
height,
160+
output_buffer,
156161
shader_input_buffer,
157162
bind_group,
163+
width,
164+
height,
165+
_user_data: PhantomData,
158166
}
159167
}
160168

@@ -188,24 +196,23 @@ impl WgpuBackend {
188196
})
189197
}
190198

191-
fn bytes_per_row(width: u16) -> u16 {
199+
fn bytes_per_row(&self, width: u16) -> u16 {
192200
let row_size = width * 4;
193201
(row_size + 255) & !255
194202
}
195203

196-
pub fn row_padding(width: u16) -> u16 {
204+
fn row_padding(&self, width: u16) -> u16 {
197205
let row_size = width * 4;
198-
let bytes_per_row = Self::bytes_per_row(width);
206+
let bytes_per_row = self.bytes_per_row(width);
199207
(bytes_per_row - row_size) / 4
200208
}
201209

202-
async fn execute_inner(&mut self, width: u16, height: u16) -> Vec<[u8; 4]> {
203-
if Self::bytes_per_row(width) != Self::bytes_per_row(self.width) || height != self.height {
210+
async fn execute_inner(&mut self, width: u16, height: u16, _user_data: &T) -> Vec<[u8; 4]> {
211+
if self.bytes_per_row(width) != self.bytes_per_row(self.width) || height != self.height {
204212
self.texture = Self::create_texture(&self.device, width.into(), height.into());
205213
self.output_buffer = Self::create_buffer(&self.device, width.into(), height.into());
206214
}
207-
208-
let bytes_per_row = Self::bytes_per_row(width);
215+
let bytes_per_row = self.bytes_per_row(width);
209216

210217
let texture_view = self
211218
.texture
@@ -283,23 +290,34 @@ impl WgpuBackend {
283290
.await
284291
.expect("unable to receive message all senders have been dropped")
285292
.expect("on unexpected error occured");
286-
let slice: Vec<[u8; 4]>;
293+
let padded_buffer: Vec<[u8; 4]>;
287294
{
288295
let view = buffer_slice.get_mapped_range();
289-
slice = bytemuck::cast_slice(&view).to_vec();
296+
padded_buffer = bytemuck::cast_slice(&view).to_vec();
290297
}
291298
self.output_buffer.unmap();
292-
slice
299+
let mut buffer: Vec<[u8; 4]> = Vec::new();
300+
for y in 0..height {
301+
for x in 0..width {
302+
let index = (y * (width + self.row_padding(width)) + x) as usize;
303+
let pixel = padded_buffer[index];
304+
buffer.push(pixel);
305+
}
306+
}
307+
buffer
293308
}
294309
}
295310

296-
impl TuiShaderBackend for WgpuBackend {
297-
fn execute(&mut self, width: u16, height: u16) -> Vec<[u8; 4]> {
298-
self.execute_inner(width, height).block_on()
311+
impl<T> TuiShaderBackend<T> for WgpuBackend<T>
312+
where
313+
T: Copy + Clone + bytemuck::Pod + bytemuck::Zeroable,
314+
{
315+
fn execute(&mut self, width: u16, height: u16, user_data: &T) -> Vec<[u8; 4]> {
316+
self.execute_inner(width, height, user_data).block_on()
299317
}
300318
}
301319

302-
impl Default for WgpuBackend {
320+
impl Default for WgpuBackend<NoUserData> {
303321
fn default() -> Self {
304322
Self::new("src/shaders/default_fragment.wgsl", "magenta")
305323
}

src/lib.rs

Lines changed: 61 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ mod backend;
3131

3232
use std::marker::PhantomData;
3333

34-
use backend::TuiShaderBackend;
34+
use backend::cpu::CpuBackend;
35+
use backend::{NoUserData, TuiShaderBackend};
3536
use ratatui::layout::{Position, Rect};
3637
use ratatui::style::{Color, Style};
3738
use ratatui::widgets::StatefulWidget;
@@ -59,17 +60,17 @@ pub struct ShaderCanvas<T> {
5960
pub character_rule: CharacterRule,
6061
pub style_rule: StyleRule,
6162
pub entry_point: String,
62-
_marker: PhantomData<T>,
63+
_user_data: PhantomData<T>,
6364
}
6465

65-
impl<T: TuiShaderBackend> ShaderCanvas<T> {
66+
impl<T> ShaderCanvas<T> {
6667
/// Creates a new instance of [`ShaderCanvas`].
6768
pub fn new() -> Self {
6869
Self {
6970
character_rule: CharacterRule::default(),
7071
style_rule: StyleRule::default(),
7172
entry_point: String::from("main"),
72-
_marker: PhantomData,
73+
_user_data: PhantomData,
7374
}
7475
}
7576

@@ -96,13 +97,13 @@ impl<T: TuiShaderBackend> ShaderCanvas<T> {
9697
}
9798
}
9899

99-
impl<T: TuiShaderBackend> Default for ShaderCanvas<T> {
100+
impl<T> Default for ShaderCanvas<T> {
100101
fn default() -> Self {
101102
Self::new()
102103
}
103104
}
104105

105-
impl<T: TuiShaderBackend> StatefulWidget for ShaderCanvas<T> {
106+
impl<T> StatefulWidget for ShaderCanvas<T> {
106107
type State = ShaderCanvasState<T>;
107108
fn render(
108109
self,
@@ -112,13 +113,12 @@ impl<T: TuiShaderBackend> StatefulWidget for ShaderCanvas<T> {
112113
) {
113114
let width = area.width;
114115
let height = area.height;
115-
116-
let raw_buffer = state.backend.execute(width, height);
116+
let samples = state.backend.execute(width, height, &state.user_data);
117117

118118
for y in 0..height {
119119
for x in 0..width {
120-
let index = (y * (width + WgpuBackend::row_padding(width)) + x) as usize;
121-
let value = raw_buffer[index];
120+
let index = (y * width + x) as usize;
121+
let value = samples[index];
122122
let position = (x, y);
123123
let character = match self.character_rule {
124124
CharacterRule::Always(character) => character,
@@ -142,35 +142,69 @@ impl<T: TuiShaderBackend> StatefulWidget for ShaderCanvas<T> {
142142
}
143143

144144
/// State struct for [`ShaderCanvas`], it holds the [`TuiShaderBackend`].
145-
#[derive(Debug, Clone, Eq, PartialEq)]
146-
pub struct ShaderCanvasState<T: TuiShaderBackend> {
147-
backend: T,
145+
pub struct ShaderCanvasState<T> {
146+
backend: Box<dyn TuiShaderBackend<T>>,
147+
user_data: T,
148148
}
149149

150-
impl ShaderCanvasState<WgpuBackend> {
150+
impl ShaderCanvasState<NoUserData> {
151151
/// Creates a new [`ShaderCanvasState`] using [`WgpuBackend`] as it's
152152
/// [`TuiShaderBackend`].
153-
pub fn wgpu(
153+
pub fn wgpu(path_to_fragment_shader: &str, entry_point: &str) -> Self {
154+
let backend = Box::new(WgpuBackend::new(path_to_fragment_shader, entry_point));
155+
Self {
156+
backend,
157+
user_data: NoUserData,
158+
}
159+
}
160+
}
161+
162+
impl<T> ShaderCanvasState<T>
163+
where
164+
T: Copy + bytemuck::Pod + bytemuck::Zeroable,
165+
{
166+
pub fn wgpu_with_user_data(
154167
path_to_fragment_shader: &str,
155168
entry_point: &str,
156-
) -> ShaderCanvasState<WgpuBackend> {
157-
let backend = WgpuBackend::new(path_to_fragment_shader, entry_point);
158-
ShaderCanvasState { backend }
169+
user_data: T,
170+
) -> Self {
171+
let backend = Box::new(WgpuBackend::new(path_to_fragment_shader, entry_point));
172+
Self { backend, user_data }
159173
}
160174
}
161175

162-
impl<T: TuiShaderBackend> ShaderCanvasState<T> {
163-
/// Creates a new [`ShaderCanvasState`] instance by passing in the desired [`TuiShaderBackend`].
164-
pub fn new(backend: T) -> ShaderCanvasState<T> {
165-
ShaderCanvasState { backend }
176+
impl ShaderCanvasState<NoUserData> {
177+
pub fn cpu<F>(callback: F) -> Self
178+
where
179+
F: Fn(u16, u16) -> [u8; 4] + 'static,
180+
{
181+
let backend = Box::new(CpuBackend::new(callback));
182+
Self {
183+
backend,
184+
user_data: NoUserData,
185+
}
186+
}
187+
}
188+
189+
impl<T> ShaderCanvasState<T>
190+
where
191+
T: 'static,
192+
{
193+
pub fn cpu_with_user_data<F>(callback: F, user_data: T) -> Self
194+
where
195+
F: Fn(u16, u16, &T) -> [u8; 4] + 'static,
196+
{
197+
let backend = Box::new(CpuBackend::new_with_user_data(callback));
198+
Self { backend, user_data }
166199
}
167200
}
168201

169-
impl Default for ShaderCanvasState<WgpuBackend> {
202+
impl Default for ShaderCanvasState<NoUserData> {
170203
/// Creates a new [`ShaderCanvasState`] instance with a [`WgpuBackend`].
171-
fn default() -> ShaderCanvasState<WgpuBackend> {
204+
fn default() -> Self {
172205
Self {
173-
backend: WgpuBackend::default(),
206+
backend: Box::new(WgpuBackend::default()),
207+
user_data: NoUserData,
174208
}
175209
}
176210
}
@@ -262,14 +296,14 @@ mod tests {
262296
#[test]
263297
fn default_wgsl_context() {
264298
let mut context = WgpuBackend::default();
265-
let raw_buffer = context.execute(64, 64);
299+
let raw_buffer = context.execute(64, 64, &NoUserData);
266300
assert!(raw_buffer.iter().all(|pixel| pixel == &[255, 0, 255, 255]));
267301
}
268302

269303
#[test]
270304
fn different_entry_points() {
271305
let mut context = WgpuBackend::new("src/shaders/default_fragment.wgsl", "green");
272-
let raw_buffer = context.execute(64, 64);
306+
let raw_buffer = context.execute(64, 64, &NoUserData);
273307
assert!(raw_buffer.iter().all(|pixel| pixel == &[0, 255, 0, 255]));
274308
}
275309

0 commit comments

Comments
 (0)