Skip to content

Commit 26bf18a

Browse files
authored
Enrich layer module with pytorch constructs. (#74)
2 parents 92b241e + 2da0c42 commit 26bf18a

4 files changed

Lines changed: 55 additions & 30 deletions

File tree

demos/video/chapel-webcam/readme.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ make bridge_objs
66
cd demos/video/chapel-webcam
77
make cleanall && make clean && make libsmol && make main
88
./main --modelPath sobel.pt
9+
10+
./main --chaiImpl=true --accelScale=0.45 --modelPath sobel.pt
911
```

demos/video/chapel-webcam/smol.chpl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,17 @@ const startTime = getTime();
7272
// ../style-transfer/models/exports/mps/starry_ep3_bt4_sw1e11_cw_1e5_float32.pt // This is the one
7373
// ../style-transfer/models/exports/cpu/mosaic_float16.pt
7474
config const modelPath: string = "../style-transfer/models/exports/cpu/mosaic_float16.pt";
75-
var model : Bridge.bridge_pt_model_t;
75+
var model : Bridge.torchModuleHandle;
7676

7777
var modelLayer : shared TorchModule(real(32))?;
7878

7979

8080
use CTypes;
8181

8282
export proc globalLoadModel() {
83-
const fpPtr: c_ptr(uint(8)) = c_ptrToConst(modelPath) : c_ptr(uint(8));
84-
model = Bridge.load_model(fpPtr);
83+
model = Bridge.loadModel(modelPath);
8584
if modelPath == "sobel.pt" then
86-
modelLayer = new shared TorchModule(modelPath);
85+
modelLayer = new shared LoadedTorchModel(modelPath);
8786
else
8887
modelLayer = new shared StyleTransfer(modelPath);
8988
}
@@ -173,9 +172,9 @@ export proc getNewFrame(ref frame: [] real(32),height: int, width: int,channels:
173172
var btFrame: Bridge.bridge_tensor_t = Bridge.createBridgeTensorWithShape(frame,shape);
174173
var bt: Bridge.bridge_tensor_t;
175174
if modelPath == "sobel.pt" then
176-
bt = Bridge.model_forward(model,btFrame);
175+
bt = Bridge.modelForward(model,btFrame);
177176
else
178-
bt = Bridge.model_forward_style_transfer(model,btFrame);
177+
bt = Bridge.modelForwardStyleTransfer(model,btFrame);
179178

180179
const nextNDFrame = bt : ndarray(3, real(32));
181180
const flattenedNextFrame = nextNDFrame.flatten().data;

lib/Bridge.chpl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ module Bridge {
4444
}
4545
}
4646

47+
proc torchModuleHandle type do return bridge_pt_model_t;
48+
4749
extern proc unsafe(const ref arr: [] real(32)): c_ptr(real(32));
4850

4951
// extern proc load_tensor_from_file(file_path: c_ptrConst(u_char)): bridge_tensor_t; // Working
@@ -61,13 +63,17 @@ module Bridge {
6163
model_path: string_t,
6264
in input: bridge_tensor_t): bridge_tensor_t;
6365

64-
extern proc load_model(model_path: string_t): bridge_pt_model_t;
66+
extern "load_model" proc loadModelC(model_path: string_t): bridge_pt_model_t;
67+
proc loadModel(modelPath: string): torchModuleHandle {
68+
const model_path: c_ptr(uint(8)) = c_ptrToConst(modelPath) : c_ptr(uint(8));
69+
return loadModelC(model_path);
70+
}
6571

66-
extern proc model_forward(
72+
extern "model_forward" proc modelForward(
6773
in model: bridge_pt_model_t,
6874
in input: bridge_tensor_t): bridge_tensor_t;
6975

70-
extern proc model_forward_style_transfer(
76+
extern "model_forward_style_transfer" proc modelForwardStyleTransfer(
7177
in model: bridge_pt_model_t,
7278
in input: bridge_tensor_t): bridge_tensor_t;
7379

lib/Layer.chpl

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -211,44 +211,62 @@ module Layer {
211211
}
212212
}
213213

214-
import CTypes;
214+
215215
class TorchModule : Module(?) {
216-
var modulePath: string;
217-
var moduleHandle: Bridge.bridge_pt_model_t;
216+
var torchModuleHandle: Bridge.torchModuleHandle;
218217

219-
proc init(type eltType, modulePath: string) {
218+
proc init(type eltType, torchModuleHandle: Bridge.torchModuleHandle) {
220219
super.init(eltType);
221-
this.modulePath = modulePath;
222-
const fpPtr: CTypes.c_ptr(uint(8)) = CTypes.c_ptrToConst(modulePath) : CTypes.c_ptr(uint(8));
223-
this.moduleHandle = Bridge.load_model(fpPtr);
220+
this.torchModuleHandle = torchModuleHandle;
224221
init this;
225222
this.moduleName = "TorchModule";
226223
}
227224

228-
proc init(modulePath: string) do
229-
this.init(defaultEltType,modulePath);
225+
proc init(torchModuleHandle: Bridge.torchModuleHandle) do
226+
this.init(defaultEltType,torchModuleHandle);
230227

231-
override proc forward(input: dynamicTensor(eltType)): dynamicTensor(eltType) {
232-
const btInput: Bridge.tensorHandle(eltType) = input : Bridge.tensorHandle(eltType);
233-
const btOutput = Bridge.model_forward(this.moduleHandle, btInput);
234-
return btOutput : dynamicTensor(eltType);
228+
override proc forward(input: dynamicTensor(eltType)):
229+
dynamicTensor(eltType) {
230+
const th = input : Bridge.tensorHandle(eltType);
231+
const thOutput = Bridge.modelForward(this.torchModuleHandle,th);
232+
return thOutput : dynamicTensor(eltType);
235233
}
236234
}
237235

238-
class StyleTransfer : TorchModule(?) {
236+
class LoadedTorchModel : TorchModule(?) {
237+
var modelPath: string;
238+
239+
proc init(type eltType, modelPath: string) {
240+
var torchModuleHandle = Bridge.loadModel(modelPath);
241+
super.init(eltType,torchModuleHandle);
242+
this.modelPath = modelPath;
243+
init this;
244+
this.moduleName = "LoadedTorchModel";
245+
}
239246

240-
proc init(type eltType, modulePath: string) do
241-
super.init(eltType,modulePath);
247+
proc init(modelPath: string) do
248+
this.init(defaultEltType,modelPath);
249+
}
250+
251+
class StyleTransfer : LoadedTorchModel(?) {
252+
proc init(type eltType, modelPath: string) {
253+
super.init(eltType,modelPath);
254+
init this;
255+
this.moduleName = "StyleTransferLoadedTorchModel";
256+
}
242257

243-
proc init(modulePath: string) do
244-
super.init(defaultEltType,modulePath);
258+
proc init(modelPath: string) do
259+
this.init(defaultEltType,modelPath);
245260

246261
override proc forward(input: dynamicTensor(eltType)):
247262
dynamicTensor(eltType) {
248-
const btInput: Bridge.tensorHandle(eltType) = input : Bridge.tensorHandle(eltType);
249-
const btOutput = Bridge.model_forward_style_transfer(this.moduleHandle, btInput);
250-
return btOutput : dynamicTensor(eltType);
263+
const th = input : Bridge.tensorHandle(eltType);
264+
const thOutput
265+
= Bridge.modelForwardStyleTransfer(this.torchModuleHandle,th);
266+
return thOutput : dynamicTensor(eltType);
251267
}
252268
}
253269

270+
271+
254272
}

0 commit comments

Comments
 (0)