@@ -65,6 +65,19 @@ torch::Tensor bridge_to_torch(bridge_tensor_t &bt) {
6565 return torch::from_blob (bt.data , shape, torch::kFloat );
6666}
6767
68+ torch::Tensor bridge_to_torch (bridge_tensor_t &bt,torch::Device device, bool copy,torch::ScalarType dtype = torch::kFloat32 ) {
69+ std::vector<int64_t > sizes_vec (bt.sizes , bt.sizes + bt.dim );
70+ auto shape = torch::IntArrayRef (sizes_vec);
71+ auto t = torch::from_blob (bt.data , shape, torch::kFloat );
72+ if (device != torch::kCPU )
73+ copy = true ;
74+ if (copy)
75+ return t.to (device, dtype, /* non_blocking=*/ false , /* copy=*/ true );
76+ else
77+ return t.to (device, dtype, /* non_blocking=*/ false , /* copy=*/ false );
78+
79+ }
80+
6881extern " C" float32_t * unsafe (const float32_t * arr) {
6982 return const_cast <float32_t *>(arr);
7083}
@@ -142,7 +155,7 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
142155 try {
143156
144157 auto * module = new torch::jit::Module (torch::jit::load (path));
145- module ->to (torch::kCPU );
158+ module ->to (torch::kMPS ,torch:: kFloat16 , false );
146159 module ->eval ();
147160 std::cout << " Model loaded successfully!" << std::endl;
148161 std::cout.flush ();
@@ -190,20 +203,19 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
190203
191204extern " C" bridge_tensor_t model_forward (bridge_pt_model_t model, bridge_tensor_t input) {
192205
193- auto tn = bridge_to_torch (input).clone ();
194- auto tn_ = tn.permute ({2 , 0 , 1 }).unsqueeze (0 ).contiguous ();
206+ auto tn_mps = bridge_to_torch (input,torch::kMPS ,true ,torch::kFloat16 );
207+ // auto tn_mps = tn.to(torch::kMPS,false,true);
208+ auto tn_ = tn_mps.permute ({2 , 0 , 1 }).unsqueeze (0 ).contiguous ();
195209
196210 std::vector<torch::jit::IValue> ins;
197211 ins.push_back (tn_);
198212
199213 auto * module = static_cast <torch::jit::Module*>(model.pt_module );
200214 auto o = module ->forward (ins).toTensor ();
201215 auto tn_out = o.squeeze (0 ).contiguous ().permute ({1 , 2 , 0 }).contiguous ();
216+ auto tn_out_cpu = tn_out.to (torch::kCPU ,torch::kFloat32 ,false ,true );
217+ return torch_to_bridge (tn_out_cpu);
202218
203- return torch_to_bridge (tn_out);
204-
205-
206- //
207219/*
208220
209221 auto t = bridge_to_torch(input).clone();
0 commit comments