Skip to content

Commit 840d3bd

Browse files
committed
chapel-webcam working except for torch model loading. Mirror screen and other things work. Probably issue with loading torch script across chapel threads.
1 parent d5423a8 commit 840d3bd

2 files changed

Lines changed: 49 additions & 18 deletions

File tree

bridge/lib/bridge.cpp

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,19 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
140140
std::cout.flush();
141141

142142
try {
143-
torch::jit::Module tmp = torch::jit::load(path);
143+
144+
auto* module = new torch::jit::Module(torch::jit::load(path));
144145
std::cout << "Model loaded successfully!" << std::endl;
145146
std::cout.flush();
146-
auto* module = new torch::jit::Module(std::move(tmp));
147-
std::cout << "Model moved successfully!" << std::endl;
148-
std::cout.flush();
149147
return { static_cast<void*>(module) };
148+
149+
// torch::jit::Module tmp = torch::jit::load(path);
150+
// std::cout << "Model loaded successfully!" << std::endl;
151+
// std::cout.flush();
152+
// auto* module = new torch::jit::Module(std::move(tmp));
153+
// std::cout << "Model moved successfully!" << std::endl;
154+
// std::cout.flush();
155+
// return { static_cast<void*>(module) };
150156
} catch (const c10::Error& e) {
151157
std::cerr << "error loading the model\n" << e.msg();
152158
std::cout << "error loading the model\n" << e.msg();
@@ -190,44 +196,68 @@ extern "C" bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_
190196
inputs.push_back(t_input);
191197
// torch::jit::Module* pt_module = (torch::jit::Module*) model.pt_module;
192198
// auto output = pt_module->forward(inputs).toTensor();
193-
auto output = t_input;
199+
auto* module = static_cast<torch::jit::Module*>(model.pt_module);
200+
auto output = module->forward(inputs).toTensor();
201+
std::cout << "Output tensor: " << output.sizes() << std::endl;
202+
std::cout.flush();
203+
// auto output = t_input;
194204
return torch_to_bridge(output);
195205
}
196206

197207
extern "C" bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model, bridge_tensor_t input) {
198208
auto input_tensor = bridge_to_torch(input);
199-
auto t_input = input_tensor.clone();
209+
auto input_tensor_copy = input_tensor.clone().contiguous();
210+
auto t_input = input_tensor_copy;
211+
auto* module = static_cast<torch::jit::Module*>(model.pt_module);
212+
213+
std::cout << "Model: " << module->dump_to_str(false, false, false) << std::endl;
214+
std::cout.flush();
200215

201216
std::cout << "Input tensor: " << t_input.sizes() << std::endl;
202217
std::cout.flush();
203218

204-
t_input = t_input.permute({2, 0, 1}).unsqueeze(0);
219+
auto model_input = input_tensor_copy.permute({2, 0, 1}).unsqueeze(0);
205220

206-
std::cout << "Input tensor reshaped: " << t_input.sizes() << std::endl;
207-
std::cout.flush();
221+
// std::cout << "Input tensor reshaped: " << model_input.sizes() << std::endl;
222+
// std::cout.flush();
223+
224+
// std::vector<torch::jit::IValue> inputs;
225+
// inputs.push_back(model_input);
226+
227+
// std::cout << "Constructed inputs: " << inputs.size() << std::endl;
228+
// std::cout.flush();
229+
230+
// return torch_to_bridge(input_tensor_copy);
208231

209232
std::vector<torch::jit::IValue> inputs;
210-
inputs.push_back(t_input);
233+
inputs.push_back(model_input);
211234

212-
std::cout << "Constructed inputs: " << inputs.size() << std::endl;
235+
std::cout << "Model input: " << model_input.sizes() << std::endl;
213236
std::cout.flush();
214237

238+
auto model_output = module->forward(inputs).toTensor();
239+
std::cout << "Output tensor: " << model_output.sizes() << std::endl;
240+
std::cout.flush();
241+
242+
auto output = model_output.div(255.0).squeeze(0).permute({1, 2, 0}).clamp(0, 1);
243+
return torch_to_bridge(output);
244+
215245
// torch::jit::script::Module & pt_module = model.pt_module;
216246

217-
auto* pt_module = static_cast<torch::jit::Module*>(model.pt_module);
247+
// auto* pt_module = static_cast<torch::jit::Module*>(model.pt_module);
218248

219-
// torch::jit::script::Module* pt_module = (torch::jit::script::Module*)model.pt_module;
220-
// std::cout << pt_module->dump_to_str(false,false,false) << std::endl;
221-
// // std::cout.flush();
222-
auto output = pt_module->forward(inputs).toTensor();
249+
// // torch::jit::script::Module* pt_module = (torch::jit::script::Module*)model.pt_module;
250+
// // std::cout << pt_module->dump_to_str(false,false,false) << std::endl;
251+
// // // std::cout.flush();
252+
// auto output = pt_module->forward(inputs).toTensor();
223253
std::cout << "Output tensor: " << output.sizes() << std::endl;
224254
std::cout.flush();
225255
// output = output.squeeze(0).permute({1, 2, 0}).clamp(0, 1).mul(255.0);
226256

227257
// std::cout << "Processed utput tensor: " << output.sizes() << std::endl;
228258
// std::cout.flush();
229259

230-
return torch_to_bridge(input_tensor);
260+
return torch_to_bridge(input_tensor_copy);
231261
}
232262

233263

demos/video/chapel-webcam/smol.chpl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ export proc getNewFrame(ref frame: [] real(32),height: int, width: int,channels:
5858
// writeln("Copying data 1");
5959
// ndframe = bt : ndframe.type;
6060

61-
var bt = Bridge.model_forward(model,ndframe : Bridge.tensorHandle(real(32)));
61+
// var bt = Bridge.model_forward(model,ndframe : Bridge.tensorHandle(real(32)));
62+
var bt = Bridge.model_forward_style_transfer(model,ndframe : Bridge.tensorHandle(real(32)));
6263
writeln("Copying data 1");
6364
ndframe.loadFromBridgeTensor(bt);
6465

0 commit comments

Comments
 (0)