Skip to content

Commit b057c16

Browse files
committed
Sobel model working with chapel-webcam demo.
1 parent 840d3bd commit b057c16

7 files changed

Lines changed: 326 additions & 59 deletions

File tree

bridge/lib/bridge.cpp

Lines changed: 38 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
142142
try {
143143

144144
auto* module = new torch::jit::Module(torch::jit::load(path));
145+
module->to(torch::kCPU);
146+
module->eval();
145147
std::cout << "Model loaded successfully!" << std::endl;
146148
std::cout.flush();
147149
return { static_cast<void*>(module) };
@@ -187,7 +189,25 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
187189
}
188190

189191
extern "C" bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input) {
190-
auto t_input = bridge_to_torch(input);
192+
193+
auto tn = bridge_to_torch(input).clone();
194+
auto tn_ = tn.permute({2, 0, 1}).unsqueeze(0).contiguous();
195+
196+
std::vector<torch::jit::IValue> ins;
197+
ins.push_back(tn_);
198+
199+
auto* module = static_cast<torch::jit::Module*>(model.pt_module);
200+
auto o = module->forward(ins).toTensor();
201+
auto tn_out = o.squeeze(0).contiguous().permute({1, 2, 0}).contiguous();
202+
203+
return torch_to_bridge(tn_out);
204+
205+
206+
//
207+
/*
208+
209+
auto t = bridge_to_torch(input).clone();
210+
auto t_input = t.permute({2, 0, 1}).unsqueeze(0); // Add batch dimension
191211
192212
std::cout << "Input tensor: " << t_input.sizes() << std::endl;
193213
std::cout.flush();
@@ -200,64 +220,35 @@ extern "C" bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_
200220
auto output = module->forward(inputs).toTensor();
201221
std::cout << "Output tensor: " << output.sizes() << std::endl;
202222
std::cout.flush();
223+
224+
auto output_reshaped = output.squeeze(0).permute({1, 2, 0}); // Remove batch dimension and permute back to HWC
225+
std::cout << "Output reshaped tensor: " << output_reshaped.sizes() << std::endl;
226+
std::cout.flush();
203227
// auto output = t_input;
204-
return torch_to_bridge(output);
228+
return torch_to_bridge(output_reshaped);
229+
*/
205230
}
206231

207232
extern "C" bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model, bridge_tensor_t input) {
208-
auto input_tensor = bridge_to_torch(input);
209-
auto input_tensor_copy = input_tensor.clone().contiguous();
210-
auto t_input = input_tensor_copy;
211-
auto* module = static_cast<torch::jit::Module*>(model.pt_module);
212-
213-
std::cout << "Model: " << module->dump_to_str(false, false, false) << std::endl;
214-
std::cout.flush();
233+
auto bt = bridge_to_torch(input).clone();
234+
auto t_input = bt.permute({2, 0, 1}).unsqueeze(0); // Convert from HWC to CHW, then add batch dimension (NCHW)
215235

216236
std::cout << "Input tensor: " << t_input.sizes() << std::endl;
217237
std::cout.flush();
218238

219-
auto model_input = input_tensor_copy.permute({2, 0, 1}).unsqueeze(0);
220-
221-
// std::cout << "Input tensor reshaped: " << model_input.sizes() << std::endl;
222-
// std::cout.flush();
223-
224-
// std::vector<torch::jit::IValue> inputs;
225-
// inputs.push_back(model_input);
226-
227-
// std::cout << "Constructed inputs: " << inputs.size() << std::endl;
228-
// std::cout.flush();
229-
230-
// return torch_to_bridge(input_tensor_copy);
231-
232239
std::vector<torch::jit::IValue> inputs;
233-
inputs.push_back(model_input);
234-
235-
std::cout << "Model input: " << model_input.sizes() << std::endl;
236-
std::cout.flush();
237-
238-
auto model_output = module->forward(inputs).toTensor();
239-
std::cout << "Output tensor: " << model_output.sizes() << std::endl;
240-
std::cout.flush();
241-
242-
auto output = model_output.div(255.0).squeeze(0).permute({1, 2, 0}).clamp(0, 1);
243-
return torch_to_bridge(output);
244-
245-
// torch::jit::script::Module & pt_module = model.pt_module;
246-
247-
// auto* pt_module = static_cast<torch::jit::Module*>(model.pt_module);
248-
249-
// // torch::jit::script::Module* pt_module = (torch::jit::script::Module*)model.pt_module;
250-
// // std::cout << pt_module->dump_to_str(false,false,false) << std::endl;
251-
// // // std::cout.flush();
240+
inputs.push_back(t_input);
241+
// torch::jit::Module* pt_module = (torch::jit::Module*) model.pt_module;
252242
// auto output = pt_module->forward(inputs).toTensor();
243+
// auto* module = static_cast<torch::jit::Module*>(model.pt_module);
244+
auto module = *static_cast<torch::jit::Module*>(model.pt_module);
245+
std::cout << "Module: " << module.dump_to_str(false, false, false) << std::endl;
246+
std::cout.flush();
247+
auto output = module.forward(inputs).toTensor();
253248
std::cout << "Output tensor: " << output.sizes() << std::endl;
254249
std::cout.flush();
255-
// output = output.squeeze(0).permute({1, 2, 0}).clamp(0, 1).mul(255.0);
256-
257-
// std::cout << "Processed utput tensor: " << output.sizes() << std::endl;
258-
// std::cout.flush();
259-
260-
return torch_to_bridge(input_tensor_copy);
250+
// auto output = t_input;
251+
return torch_to_bridge(output);
261252
}
262253

263254

675 KB
Loading

demos/video/chapel-webcam/main.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ cv::Mat new_frame(cv::Mat &frame) {
2020

2121
// cv::MatSize size = rgb_frame.size;
2222
// std::cout << "x " << size[0] << " y " << size[1] << " channels " << rgb_frame.dims << std::endl;
23-
int64_t width = rgb_float_frame.cols;
2423
int64_t height = rgb_float_frame.rows;
24+
int64_t width = rgb_float_frame.cols;
2525
int64_t channels = rgb_float_frame.channels();
2626
int64_t pixels = rgb_float_frame.total();
2727
int64_t size = pixels * channels;
2828

29+
std::cout << "Width: " << width << ", Height: " << height << ", Channels: " << channels << ", Size: " << size << std::endl;
30+
2931
chpl_external_array
3032
rgb_float_frame_data_ptr = chpl_make_external_array_ptr(rgb_float_frame.data,size);
3133

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "657a8f27",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import torch"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": null,
16+
"id": "95f3b45d",
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"class CustomModel(torch.nn.Module):\n",
21+
" def __init__(self):\n",
22+
" super(CustomModel, self).__init__()\n",
23+
" self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)\n",
24+
" self.flatten = torch.nn.Flatten()\n",
25+
" self.fc = torch.nn.Linear(16 * 1920 * 1080, 10)\n",
26+
"\n",
27+
" def forward(self, x):\n",
28+
" return self.fc(self.flatten(self.conv1(x)))\n",
29+
"\n",
30+
"model = CustomModel()\n",
31+
"model.eval()\n",
32+
"\n",
33+
"sm = torch.jit.script(model.to(torch.float32))\n",
34+
"sm.save(f\"model.pt\")\n"
35+
]
36+
},
37+
{
38+
"cell_type": "code",
39+
"execution_count": null,
40+
"id": "6927bdb1",
41+
"metadata": {},
42+
"outputs": [],
43+
"source": [
44+
"x = torch.randn(1, 3, 1920, 1080).to(torch.float32)\n",
45+
"y = model(x)\n",
46+
"y.shape"
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": null,
52+
"id": "72df7a4f",
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
56+
"x = torch.randn(1, 3, 1080, 1920).to(torch.float32)\n",
57+
"y = model(x)"
58+
]
59+
},
60+
{
61+
"cell_type": "code",
62+
"execution_count": null,
63+
"id": "a8cca72a",
64+
"metadata": {},
65+
"outputs": [],
66+
"source": [
67+
"class Sobel(torch.nn.Module):\n",
68+
" def __init__(self):\n",
69+
" super(Sobel, self).__init__()\n",
70+
" sobel_dx = torch.tensor([[-1, 0, 1],\n",
71+
" [-2, 0, 2],\n",
72+
" [-1, 0, 1]], dtype=torch.float32)\n",
73+
"\n",
74+
" sobel_dy = torch.tensor([[-1, -2, -1],\n",
75+
" [ 0, 0, 0],\n",
76+
" [ 1, 2, 1]], dtype=torch.float32)\n",
77+
"\n",
78+
" kernel_dx = sobel_dx.view(1,1,3,3).repeat(3,1,1,1).contiguous()\n",
79+
" kernel_dy = sobel_dy.view(1,1,3,3).repeat(3,1,1,1).contiguous()\n",
80+
"\n",
81+
" self.kernel_dx = torch.nn.Parameter(kernel_dx, requires_grad=False)\n",
82+
" self.kernel_dy = torch.nn.Parameter(kernel_dy, requires_grad=False)\n",
83+
"\n",
84+
"\n",
85+
" # sobel_kernel = torch.stack([sobel_dx, sobel_dy]) # [2,3,3]\n",
86+
" # sobel_kernel = sobel_kernel.unsqueeze(1).repeat(1, 3, 1, 1) # [2,3,3,3]\n",
87+
" # sobel_kernel = sobel_kernel.to(torch.float32)\n",
88+
"\n",
89+
" # self.sobel_kernel = torch.nn.Parameter(sobel_kernel, requires_grad=False)\n",
90+
" # # self.sobel_cnn = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False).to(torch.float16)\n",
91+
" # # self.sobel_cnn.weight = torch.nn.Parameter(sobel_kernel, requires_grad=False)\n",
92+
"\n",
93+
" def forward(self, x):\n",
94+
" # return self.sobel_cnn(x)\n",
95+
" grad_x = torch.nn.functional.conv2d(x, self.kernel_dx, padding=1, groups=3)\n",
96+
" grad_y = torch.nn.functional.conv2d(x, self.kernel_dy, padding=1, groups=3)\n",
97+
"\n",
98+
" grad_mag = torch.sqrt(grad_x ** 2 + grad_y ** 2)\n",
99+
" return grad_mag\n",
100+
"\n",
101+
" # return torch.nn.functional.conv2d(x, self.sobel_kernel, padding=1,stride=1)\n",
102+
"\n",
103+
"\n",
104+
"\n",
105+
"# sobel = Sobel().to('mps').to(torch.float32)\n",
106+
"# sm = torch.jit.script(sobel)\n",
107+
"# sm.save(\"models/sobel_float32.pt\")\n",
108+
"# sobel = Sobel().to('mps').to(torch.float16)\n",
109+
"# sm = torch.jit.script(sobel)\n",
110+
"# sm.save(\"models/sobel_float16.pt\")"
111+
]
112+
},
113+
{
114+
"cell_type": "code",
115+
"execution_count": null,
116+
"id": "1ce18cda",
117+
"metadata": {},
118+
"outputs": [],
119+
"source": [
120+
"sobel = Sobel().to(torch.float32)\n",
121+
"sobel.eval()\n",
122+
"sm = torch.jit.script(sobel)\n",
123+
"sm.save(\"sobel.pt\")"
124+
]
125+
},
126+
{
127+
"cell_type": "code",
128+
"execution_count": null,
129+
"id": "8ec3d21a",
130+
"metadata": {},
131+
"outputs": [],
132+
"source": [
133+
"from torchvision.io import decode_image\n",
134+
"from PIL import Image\n",
135+
"import matplotlib.pyplot as plt\n",
136+
"def show_image(image):\n",
137+
" plt.imshow(image.permute(1, 2, 0).cpu())\n",
138+
" plt.axis('off')\n",
139+
" plt.show()\n",
140+
"\n",
141+
"\n",
142+
"img = decode_image('coast.jpeg', mode='RGB')\n",
143+
"img = img.to(torch.float32) / 255.0\n",
144+
"# plt.imshow(img)\n",
145+
"# plt.show()\n",
146+
"print(img.shape)\n",
147+
"print(img.permute(1, 2, 0).shape)\n",
148+
"# show_image(img)"
149+
]
150+
},
151+
{
152+
"cell_type": "code",
153+
"execution_count": null,
154+
"id": "5780a986",
155+
"metadata": {},
156+
"outputs": [],
157+
"source": [
158+
"sobel_img = sobel(img.unsqueeze(0))\n",
159+
"print(img.shape)\n",
160+
"print(sobel_img.shape)\n",
161+
"show_image(sobel_img.squeeze(0))"
162+
]
163+
},
164+
{
165+
"cell_type": "code",
166+
"execution_count": null,
167+
"id": "7e2e6286",
168+
"metadata": {},
169+
"outputs": [],
170+
"source": [
171+
"\n",
172+
"\n"
173+
]
174+
},
175+
{
176+
"cell_type": "code",
177+
"execution_count": null,
178+
"id": "f353fa03",
179+
"metadata": {},
180+
"outputs": [],
181+
"source": [
182+
"def chat_sobel(img):\n",
183+
" import torch\n",
184+
" import torch.nn.functional as F\n",
185+
"\n",
186+
" # assume img is a FloatTensor of shape [3, H, W], e.g. normalized to [0,1]\n",
187+
" # step 0: add batch dim\n",
188+
" img = img.unsqueeze(0) # now [1, 3, H, W]\n",
189+
"\n",
190+
" # 1. define 2D Sobel kernels\n",
191+
" sobel_dx = torch.tensor([[-1., 0., 1.],\n",
192+
" [-2., 0., 2.],\n",
193+
" [-1., 0., 1.]], dtype=torch.float32)\n",
194+
" sobel_dy = torch.tensor([[-1., -2., -1.],\n",
195+
" [ 0., 0., 0.],\n",
196+
" [ 1., 2., 1.]], dtype=torch.float32)\n",
197+
"\n",
198+
" # 2. reshape them into conv filters of shape (out_ch, in_ch_per_group, kH, kW)\n",
199+
" # here we want 1 filter per input channel, done 3 times (one group per channel)\n",
200+
" kernel_dx = sobel_dx.view(1,1,3,3).repeat(3,1,1,1) # → (3,1,3,3)\n",
201+
" kernel_dy = sobel_dy.view(1,1,3,3).repeat(3,1,1,1) # → (3,1,3,3)\n",
202+
"\n",
203+
" # 3. apply grouped conv so each channel is convolved separately\n",
204+
" grad_x = F.conv2d(img, kernel_dx, padding=1, groups=3) # [1,3,H,W]\n",
205+
" grad_y = F.conv2d(img, kernel_dy, padding=1, groups=3) # [1,3,H,W]\n",
206+
"\n",
207+
" # 4. compute magnitude per channel\n",
208+
" grad_mag = torch.sqrt(grad_x**2 + grad_y**2) # [1,3,H,W]\n",
209+
"\n",
210+
" # 5. squeeze off the batch dim\n",
211+
" out_img = grad_mag.squeeze(0) # → [3, H, W]\n",
212+
"\n",
213+
" print(out_img.shape) # torch.Size([3, 1080, 1920])\n",
214+
" return out_img\n",
215+
"\n",
216+
"img2 = img.clone()\n",
217+
"out_img = chat_sobel(img2)\n",
218+
"show_image(out_img)"
219+
]
220+
},
221+
{
222+
"cell_type": "code",
223+
"execution_count": null,
224+
"id": "ad8bee17",
225+
"metadata": {},
226+
"outputs": [],
227+
"source": []
228+
}
229+
],
230+
"metadata": {
231+
"kernelspec": {
232+
"display_name": ".venv",
233+
"language": "python",
234+
"name": "python3"
235+
},
236+
"language_info": {
237+
"codemirror_mode": {
238+
"name": "ipython",
239+
"version": 3
240+
},
241+
"file_extension": ".py",
242+
"mimetype": "text/x-python",
243+
"name": "python",
244+
"nbconvert_exporter": "python",
245+
"pygments_lexer": "ipython3",
246+
"version": "3.12.9"
247+
}
248+
},
249+
"nbformat": 4,
250+
"nbformat_minor": 5
251+
}

0 commit comments

Comments
 (0)