Skip to content

Commit b129992

Browse files
committed
Chapel-webcam demo working with starry night.
1 parent 9179071 commit b129992

4 files changed

Lines changed: 38 additions & 67 deletions

File tree

bridge/include/bridge.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ bridge_tensor_t load_tensor_dict_from_file(const uint8_t* file_path,const uint8_
5050
bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tensor_t input);
5151

5252
bridge_pt_model_t load_model(const uint8_t* model_path);
53+
5354
bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input);
55+
56+
5457
bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model, bridge_tensor_t input);
5558

5659
bridge_tensor_t resize(bridge_tensor_t input,int height,int width);

bridge/lib/bridge.cpp

Lines changed: 14 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -201,66 +201,36 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
201201
// return model_wrapper;
202202
}
203203

204-
extern "C" bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input) {
204+
205+
206+
bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input, bool is_vgg_based_model) {
205207

206208
auto tn_mps = bridge_to_torch(input,torch::kMPS,true,torch::kFloat16);
207209
// auto tn_mps = tn.to(torch::kMPS,false,true);
208-
auto tn_ = tn_mps.permute({2, 0, 1}).unsqueeze(0).contiguous();
210+
auto tn = tn_mps.permute({2, 0, 1}).unsqueeze(0).contiguous();
209211

210212
std::vector<torch::jit::IValue> ins;
211-
ins.push_back(tn_);
213+
ins.push_back(tn);
212214

213215
auto* module = static_cast<torch::jit::Module*>(model.pt_module);
214216
auto o = module->forward(ins).toTensor();
215217
auto tn_out = o.squeeze(0).contiguous().permute({1, 2, 0}).contiguous();
216-
auto tn_out_cpu = tn_out.to(torch::kCPU,torch::kFloat32,false,true);
217-
return torch_to_bridge(tn_out_cpu);
218-
219-
/*
220218

221-
auto t = bridge_to_torch(input).clone();
222-
auto t_input = t.permute({2, 0, 1}).unsqueeze(0); // Add batch dimension
219+
if (is_vgg_based_model) {
220+
tn_out = tn_out / 255.0;
221+
}
223222

224-
std::cout << "Input tensor: " << t_input.sizes() << std::endl;
225-
std::cout.flush();
223+
auto tn_out_cpu = tn_out.to(torch::kCPU,torch::kFloat32,false,true);
224+
return torch_to_bridge(tn_out_cpu);
226225

227-
std::vector<torch::jit::IValue> inputs;
228-
inputs.push_back(t_input);
229-
// torch::jit::Module* pt_module = (torch::jit::Module*) model.pt_module;
230-
// auto output = pt_module->forward(inputs).toTensor();
231-
auto* module = static_cast<torch::jit::Module*>(model.pt_module);
232-
auto output = module->forward(inputs).toTensor();
233-
std::cout << "Output tensor: " << output.sizes() << std::endl;
234-
std::cout.flush();
226+
}
235227

236-
auto output_reshaped = output.squeeze(0).permute({1, 2, 0}); // Remove batch dimension and permute back to HWC
237-
std::cout << "Output reshaped tensor: " << output_reshaped.sizes() << std::endl;
238-
std::cout.flush();
239-
// auto output = t_input;
240-
return torch_to_bridge(output_reshaped);
241-
*/
228+
extern "C" bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input) {
229+
return model_forward(model, input, false);
242230
}
243231

244232
extern "C" bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model, bridge_tensor_t input) {
245-
auto bt = bridge_to_torch(input).clone();
246-
auto t_input = bt.permute({2, 0, 1}).unsqueeze(0); // Convert from CHW to HWC
247-
248-
std::cout << "Input tensor: " << t_input.sizes() << std::endl;
249-
std::cout.flush();
250-
251-
std::vector<torch::jit::IValue> inputs;
252-
inputs.push_back(t_input);
253-
// torch::jit::Module* pt_module = (torch::jit::Module*) model.pt_module;
254-
// auto output = pt_module->forward(inputs).toTensor();
255-
// auto* module = static_cast<torch::jit::Module*>(model.pt_module);
256-
auto module = *static_cast<torch::jit::Module*>(model.pt_module);
257-
std::cout << "Module: " << module.dump_to_str(false, false, false) << std::endl;
258-
std::cout.flush();
259-
auto output = module.forward(inputs).toTensor();
260-
std::cout << "Output tensor: " << output.sizes() << std::endl;
261-
std::cout.flush();
262-
// auto output = t_input;
263-
return torch_to_bridge(output);
233+
return model_forward(model, input, true);
264234
}
265235

266236

demos/video/chapel-webcam/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ OPENCV_LDFLAGS := $(shell pkg-config --cflags --libs opencv4)
6565

6666
main: main.cpp smol_wrapper.h lib/libsmol.so
6767
@echo $(OPENCV_CFLAGS)
68-
$(CHPL_LINKER) $(CHPL_CFLAGS) $(OPENCV_CFLAGS) $(BRIDGE_CFLAGS) -std=c++20 -fPIC main.cpp -o main $(CHPL_LDFLAGS) $(OPENCV_LDFLAGS) $(BRIDGE_LDFLAGS)
68+
$(CHPL_LINKER) $(CHPL_CFLAGS) $(OPENCV_CFLAGS) $(BRIDGE_CFLAGS) -O2 -std=c++20 -fPIC main.cpp -o main $(CHPL_LDFLAGS) $(OPENCV_LDFLAGS) $(BRIDGE_LDFLAGS)
6969

7070
clean:
7171
rm -f maincpp maincpp.o main.o main

demos/video/chapel-webcam/smol.chpl

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,63 +33,61 @@ proc getTime() {
3333

3434
const startTime = getTime();
3535

36-
config const modelPath: string = "sobel.pt"; // "../style-transfer/models/exports/cpu/starry_ep3_bt4_sw1e11_cw_1e5_float32.pt";
36+
config const modelPath: string = "../style-transfer/models/exports/mps/starry_ep3_bt4_sw1e11_cw_1e5_float32.pt";
3737
var model : Bridge.bridge_pt_model_t;
3838

3939
use CTypes;
4040

4141
export proc globalLoadModel() {
4242
const fpPtr: c_ptr(uint(8)) = c_ptrToConst(modelPath) : c_ptr(uint(8));
4343
model = Bridge.load_model(fpPtr);
44+
45+
// const fpPtr: c_ptr(uint(8)) = c_ptrToConst(modelPath) : c_ptr(uint(8));
46+
// var model = Bridge.load_model(fpPtr);
4447
}
4548

4649

50+
var lastFrame = startTime;
4751

4852
export proc getNewFrame(ref frame: [] real(32),height: int, width: int,channels: int): [] real(32) {
4953

5054
const t = getTime() - startTime;
55+
const dt = getTime() - lastFrame;
56+
writeln("FPS: ", 1.0 / dt);
5157
const shape = (height,width,channels);
52-
writeln(shape);
53-
writeln("frame size: ",frame.size);
54-
5558
const frameDom = utils.domainFromShape((...shape));
56-
5759
const shapedFrame = [idx in frameDom] frame[utils.linearIdx(shape,idx)];
5860

5961
var ndframe = new ndarray(shapedFrame);
60-
writeln("ndframe shape: ", ndframe.shape);
6162

6263
// const nf = ndframe.flatten().data;
6364
// return nf;
64-
6565
// var ndframe = new ndarray(real(32),shape);
6666
// ndframe.data = reshape(frame,ndframe.domain);
67-
6867
// var bt = Bridge.model_forward_style_transfer(model,ndframe : Bridge.tensorHandle(real(32)));
6968
// writeln("Copying data 1");
7069
// ndframe = bt : ndframe.type;
7170

7271
// var bt = Bridge.model_forward(model,ndframe : Bridge.tensorHandle(real(32)));
73-
const fpPtr: c_ptr(uint(8)) = c_ptrToConst(modelPath) : c_ptr(uint(8));
74-
var model = Bridge.load_model(fpPtr);
75-
var bt = Bridge.model_forward(model,ndframe : Bridge.tensorHandle(real(32)));
7672

77-
// var bt = Bridge.model_forward_style_transfer(model,ndframe : Bridge.tensorHandle(real(32)));
78-
writeln("Copying data 1");
73+
var bt: Bridge.bridge_tensor_t;
74+
if modelPath == "sobel.pt" then
75+
bt = Bridge.model_forward(model,ndframe : Bridge.tensorHandle(real(32)));
76+
else
77+
bt = Bridge.model_forward_style_transfer(model,ndframe : Bridge.tensorHandle(real(32)));
78+
7979
// ndframe.loadFromBridgeTensor(bt);
8080
const nextNDFrame = bt : ndarray(3, real(32));
8181

82-
writeln("nextFrame shape: ", nextNDFrame.shape);
83-
84-
writeln("Copying data 2");
82+
forall i in 0..<frame.size {
83+
const idx = utils.indexAt(i,(...shape));
84+
ref color = frame[i];
85+
color = nextNDFrame.data[idx];
86+
}
8587

86-
// forall i in 0..<frame.size {
87-
// const idx = utils.indexAt(i,(...shape));
88-
// ref color = frame[i];
89-
// color = nextNDFrame.data[idx];
90-
// }
88+
lastFrame = getTime();
9189

92-
// return frame;
90+
return frame;
9391

9492
const flattenedNextFrame = nextNDFrame.flatten().data;
9593
return flattenedNextFrame;

0 commit comments

Comments
 (0)