Skip to content

Commit ad5cb4a

Browse files
committed
Model loading working in chapel-webcam.
1 parent 7ba7b3a commit ad5cb4a

6 files changed

Lines changed: 38 additions & 9 deletions

File tree

bridge/include/bridge.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ typedef struct bridge_tensor_t {
2626

2727

2828
typedef struct bridge_pt_model_t {
29-
uint64_t pt_module;
29+
void* pt_module;
3030
} bridge_pt_model_t;
3131

3232
typedef struct test_struct_t {

bridge/lib/bridge.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ extern "C" bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tens
132132

133133

134134
extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
135+
std::cout << "Begin loading model from path: " << model_path << std::endl;
136+
std::cout.flush();
135137
std::string mp(reinterpret_cast<const char*>(model_path));
136138
std::cout << "Loading model from path: " << mp << std::endl;
137139
std::cout.flush();
@@ -141,20 +143,28 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
141143
try {
142144
*pt_module = torch::jit::load(mp);
143145
std::cout << "Model loaded successfully!" << std::endl;
144-
model_wrapper.pt_module = (uint64_t) pt_module;
146+
std::cout.flush();
147+
model_wrapper.pt_module = pt_module;
145148
} catch (const c10::Error& e) {
146149
std::cerr << "error loading the model\n" << e.msg();
150+
std::cout << "error loading the model\n" << e.msg();
151+
std::cout.flush();
152+
std::cerr.flush();
147153
}
148154

155+
std::cout << pt_module->dump_to_str(false,false,false) << std::endl;
156+
std::cout.flush();
157+
149158
return model_wrapper;
150159
}
151160

152161
extern "C" bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input) {
153162
auto t_input = bridge_to_torch(input);
154163
std::vector<torch::jit::IValue> inputs;
155164
inputs.push_back(t_input);
156-
torch::jit::Module* pt_module = (torch::jit::Module*) model.pt_module;
157-
auto output = pt_module->forward(inputs).toTensor();
165+
// torch::jit::Module* pt_module = (torch::jit::Module*) model.pt_module;
166+
// auto output = pt_module->forward(inputs).toTensor();
167+
auto output = t_input;
158168
return torch_to_bridge(output);
159169
}
160170

demos/video/chapel-webcam/lib/smol.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ void chpl__init_smol(int64_t _ln,
1919
int32_t _fn);
2020
int64_t square(int64_t x);
2121
void printArray(chpl_external_array * a);
22+
void globalLoadModel(void);
2223
chpl_external_array getNewFrame(chpl_external_array * frame,
2324
int64_t height,
2425
int64_t width,

demos/video/chapel-webcam/main.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ int mirror() {
138138

139139
int main(int argc, char* argv[]) {
140140
chpl_library_init(argc, argv);
141+
142+
chpl__init_Bridge(0, 0);
143+
chpl__init_smol(0, 0);
141144

142145
square(3);
143146

@@ -153,6 +156,8 @@ int main(int argc, char* argv[]) {
153156
printArray(&matrix_ptr);
154157
chpl_free_external_array(matrix_ptr);
155158

159+
globalLoadModel();
160+
156161
int code = mirror();
157162

158163

demos/video/chapel-webcam/smol.chpl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,15 @@ proc getTime() {
3333

3434
const startTime = getTime();
3535

36-
config const modelPath: string;
36+
const modelPath: string = "../style-transfer/models/exports/cpu/starry_ep3_bt4_sw1e11_cw_1e5_float32.pt";
37+
var model : Bridge.bridge_pt_model_t;
3738

3839
use CTypes;
39-
const fpPtr: c_ptr(uint(8)) = c_ptrToConst(modelPath) : c_ptr(uint(8));
40-
const model = Bridge.load_model(fpPtr);
40+
41+
export proc globalLoadModel() {
42+
const fpPtr: c_ptr(uint(8)) = c_ptrToConst(modelPath) : c_ptr(uint(8));
43+
model = Bridge.load_model(fpPtr);
44+
}
4145

4246

4347

@@ -51,7 +55,16 @@ export proc getNewFrame(ref frame: [] real(32),height: int, width: int,channels:
5155
var ndframe = new ndarray(real(32),shape);
5256
ndframe.data = reshape(frame,ndframe.domain);
5357

54-
writeln(ndframe.max());
58+
var bt = Bridge.model_forward(model,ndframe : Bridge.tensorHandle(real(32)));
59+
ndframe.loadFromBridgeTensor(bt);
60+
61+
forall i in 0..<frame.size {
62+
const idx = utils.indexAt(i,(...shape));
63+
ref color = frame[i];
64+
color = ndframe.data[idx];
65+
}
66+
return frame;
67+
5568
forall i in 0..<frame.size {
5669
const idx = utils.indexAt(i,(...shape));
5770
const (h,w,c) = idx;

lib/Bridge.chpl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ module Bridge {
1919
}
2020

2121
extern record bridge_pt_model_t {
22-
var pt_module: uint(64);
22+
var pt_module: c_ptr(void);
2323
}
2424
extern record test_struct_t {
2525
var field: c_ptr(int(32));

0 commit comments

Comments
 (0)