Skip to content

Commit 9179071

Browse files
committed
Add support for MPS acceleration and float16 dtype models on chapel-webcam demo.
1 parent 5709906 commit 9179071

6 files changed

Lines changed: 189 additions & 23 deletions

File tree

bridge/lib/bridge.cpp

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,19 @@ torch::Tensor bridge_to_torch(bridge_tensor_t &bt) {
6565
return torch::from_blob(bt.data, shape, torch::kFloat);
6666
}
6767

68+
// Builds a torch tensor from a bridge tensor and moves/casts it in one step.
//
// The buffer bt.data is wrapped with torch::from_blob as a kFloat tensor whose
// shape is bt.sizes[0 .. bt.dim), and the result of Tensor::to(device, dtype)
// is returned.
//
//   bt     - bridge tensor supplying data, sizes and dim.
//   device - target device for the returned tensor.
//   copy   - request a copy; a copy is always requested when device is not CPU.
//   dtype  - target scalar type (defaults to torch::kFloat32).
//
// NOTE(review): with copy == false on CPU the returned tensor presumably still
// refers to bt.data -- confirm against the from_blob / Tensor::to docs before
// relying on the buffer's lifetime.
torch::Tensor bridge_to_torch(bridge_tensor_t &bt,torch::Device device, bool copy,torch::ScalarType dtype = torch::kFloat32) {
    const std::vector<int64_t> dims(bt.sizes, bt.sizes + bt.dim);
    auto wrapped = torch::from_blob(bt.data, torch::IntArrayRef(dims), torch::kFloat);

    // Off-CPU targets always get a copy, regardless of what the caller asked for.
    const bool want_copy = copy || (device != torch::kCPU);

    return wrapped.to(device, dtype, /*non_blocking=*/false, /*copy=*/want_copy);
}
80+
6881
// C-linkage helper: returns the same address it was given, with the const
// qualifier removed from the pointee type.
extern "C" float32_t* unsafe(const float32_t* arr) {
    float32_t* writable = const_cast<float32_t*>(arr);
    return writable;
}
@@ -142,7 +155,7 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
142155
try {
143156

144157
auto* module = new torch::jit::Module(torch::jit::load(path));
145-
module->to(torch::kCPU);
158+
module->to(torch::kMPS,torch::kFloat16,false);
146159
module->eval();
147160
std::cout << "Model loaded successfully!" << std::endl;
148161
std::cout.flush();
@@ -190,20 +203,19 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
190203

191204
extern "C" bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input) {
192205

193-
auto tn = bridge_to_torch(input).clone();
194-
auto tn_ = tn.permute({2, 0, 1}).unsqueeze(0).contiguous();
206+
auto tn_mps = bridge_to_torch(input,torch::kMPS,true,torch::kFloat16);
207+
// auto tn_mps = tn.to(torch::kMPS,false,true);
208+
auto tn_ = tn_mps.permute({2, 0, 1}).unsqueeze(0).contiguous();
195209

196210
std::vector<torch::jit::IValue> ins;
197211
ins.push_back(tn_);
198212

199213
auto* module = static_cast<torch::jit::Module*>(model.pt_module);
200214
auto o = module->forward(ins).toTensor();
201215
auto tn_out = o.squeeze(0).contiguous().permute({1, 2, 0}).contiguous();
216+
auto tn_out_cpu = tn_out.to(torch::kCPU,torch::kFloat32,false,true);
217+
return torch_to_bridge(tn_out_cpu);
202218

203-
return torch_to_bridge(tn_out);
204-
205-
206-
//
207219
/*
208220
209221
auto t = bridge_to_torch(input).clone();

demos/video/chapel-webcam/Makefile

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@ OPENCV_LDFLAGS := $(shell pkg-config --cflags --libs opencv4)
6565

6666
main: main.cpp smol_wrapper.h lib/libsmol.so
6767
@echo $(OPENCV_CFLAGS)
68-
$(CHPL_LINKER) $(CHPL_CFLAGS) $(OPENCV_CFLAGS) $(BRIDGE_CFLAGS) -O2 -std=c++20 -fPIC main.cpp -o main $(CHPL_LDFLAGS) $(OPENCV_LDFLAGS) $(BRIDGE_LDFLAGS)
68+
$(CHPL_LINKER) $(CHPL_CFLAGS) $(OPENCV_CFLAGS) $(BRIDGE_CFLAGS) -std=c++20 -fPIC main.cpp -o main $(CHPL_LDFLAGS) $(OPENCV_LDFLAGS) $(BRIDGE_LDFLAGS)
6969

7070
clean:
7171
rm -f maincpp maincpp.o main.o main

demos/video/chapel-webcam/model.ipynb

Lines changed: 168 additions & 14 deletions
Large diffs are not rendered by default.

demos/video/chapel-webcam/smol.chpl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ proc getTime() {
3333

3434
const startTime = getTime();
3535

36-
config const modelPath: string = "../style-transfer/models/exports/cpu/starry_ep3_bt4_sw1e11_cw_1e5_float32.pt";
36+
config const modelPath: string = "sobel.pt"; // "../style-transfer/models/exports/cpu/starry_ep3_bt4_sw1e11_cw_1e5_float32.pt";
3737
var model : Bridge.bridge_pt_model_t;
3838

3939
use CTypes;

demos/video/chapel-webcam/sobel.pt

0 Bytes
Binary file not shown.
6.47 MB
Binary file not shown.

0 commit comments

Comments
 (0)