diff --git a/demos/video/chapel-webcam/readme.md b/demos/video/chapel-webcam/readme.md index f3987157..f4b3af86 100644 --- a/demos/video/chapel-webcam/readme.md +++ b/demos/video/chapel-webcam/readme.md @@ -6,4 +6,6 @@ make bridge_objs cd demos/video/chapel-webcam make cleanall && make clean && make libsmol && make main ./main --modelPath sobel.pt + +./main --chaiImpl=true --accelScale=0.45 --modelPath sobel.pt ``` \ No newline at end of file diff --git a/demos/video/chapel-webcam/smol.chpl b/demos/video/chapel-webcam/smol.chpl index 7841f4e2..f97453f8 100644 --- a/demos/video/chapel-webcam/smol.chpl +++ b/demos/video/chapel-webcam/smol.chpl @@ -72,7 +72,7 @@ const startTime = getTime(); // ../style-transfer/models/exports/mps/starry_ep3_bt4_sw1e11_cw_1e5_float32.pt // This is the one // ../style-transfer/models/exports/cpu/mosaic_float16.pt config const modelPath: string = "../style-transfer/models/exports/cpu/mosaic_float16.pt"; -var model : Bridge.bridge_pt_model_t; +var model : Bridge.torchModuleHandle; var modelLayer : shared TorchModule(real(32))?; @@ -80,10 +80,9 @@ var modelLayer : shared TorchModule(real(32))?; use CTypes; export proc globalLoadModel() { - const fpPtr: c_ptr(uint(8)) = c_ptrToConst(modelPath) : c_ptr(uint(8)); - model = Bridge.load_model(fpPtr); + model = Bridge.loadModel(modelPath); if modelPath == "sobel.pt" then - modelLayer = new shared TorchModule(modelPath); + modelLayer = new shared LoadedTorchModel(modelPath); else modelLayer = new shared StyleTransfer(modelPath); } @@ -173,9 +172,9 @@ export proc getNewFrame(ref frame: [] real(32),height: int, width: int,channels: var btFrame: Bridge.bridge_tensor_t = Bridge.createBridgeTensorWithShape(frame,shape); var bt: Bridge.bridge_tensor_t; if modelPath == "sobel.pt" then - bt = Bridge.model_forward(model,btFrame); + bt = Bridge.modelForward(model,btFrame); else - bt = Bridge.model_forward_style_transfer(model,btFrame); + bt = Bridge.modelForwardStyleTransfer(model,btFrame); const nextNDFrame = bt : ndarray(3, real(32)); const flattenedNextFrame = nextNDFrame.flatten().data; diff --git a/lib/Bridge.chpl b/lib/Bridge.chpl index fd1e6c20..f0b813c0 100644 --- a/lib/Bridge.chpl +++ b/lib/Bridge.chpl @@ -44,6 +44,8 @@ module Bridge { } } + proc torchModuleHandle type do return bridge_pt_model_t; + extern proc unsafe(const ref arr: [] real(32)): c_ptr(real(32)); // extern proc load_tensor_from_file(file_path: c_ptrConst(u_char)): bridge_tensor_t; // Working @@ -61,13 +63,17 @@ module Bridge { model_path: string_t, in input: bridge_tensor_t): bridge_tensor_t; - extern proc load_model(model_path: string_t): bridge_pt_model_t; + extern "load_model" proc loadModelC(model_path: string_t): bridge_pt_model_t; + proc loadModel(modelPath: string): torchModuleHandle { + const model_path: c_ptr(uint(8)) = c_ptrToConst(modelPath) : c_ptr(uint(8)); + return loadModelC(model_path); + } - extern proc model_forward( + extern "model_forward" proc modelForward( in model: bridge_pt_model_t, in input: bridge_tensor_t): bridge_tensor_t; - extern proc model_forward_style_transfer( + extern "model_forward_style_transfer" proc modelForwardStyleTransfer( in model: bridge_pt_model_t, in input: bridge_tensor_t): bridge_tensor_t; diff --git a/lib/Layer.chpl b/lib/Layer.chpl index db5639ac..74df93dc 100644 --- a/lib/Layer.chpl +++ b/lib/Layer.chpl @@ -211,44 +211,62 @@ module Layer { } } - import CTypes; + class TorchModule : Module(?) { - var modulePath: string; - var moduleHandle: Bridge.bridge_pt_model_t; + var torchModuleHandle: Bridge.torchModuleHandle; - proc init(type eltType, modulePath: string) { + proc init(type eltType, torchModuleHandle: Bridge.torchModuleHandle) { super.init(eltType); - this.modulePath = modulePath; - const fpPtr: CTypes.c_ptr(uint(8)) = CTypes.c_ptrToConst(modulePath) : CTypes.c_ptr(uint(8)); - this.moduleHandle = Bridge.load_model(fpPtr); + this.torchModuleHandle = torchModuleHandle; init this; this.moduleName = "TorchModule"; } - proc init(modulePath: string) do - this.init(defaultEltType,modulePath); + proc init(torchModuleHandle: Bridge.torchModuleHandle) do + this.init(defaultEltType,torchModuleHandle); - override proc forward(input: dynamicTensor(eltType)): dynamicTensor(eltType) { - const btInput: Bridge.tensorHandle(eltType) = input : Bridge.tensorHandle(eltType); - const btOutput = Bridge.model_forward(this.moduleHandle, btInput); - return btOutput : dynamicTensor(eltType); + override proc forward(input: dynamicTensor(eltType)): + dynamicTensor(eltType) { + const th = input : Bridge.tensorHandle(eltType); + const thOutput = Bridge.modelForward(this.torchModuleHandle,th); + return thOutput : dynamicTensor(eltType); } } - class StyleTransfer : TorchModule(?) { + class LoadedTorchModel : TorchModule(?) { + var modelPath: string; + + proc init(type eltType, modelPath: string) { + var torchModuleHandle = Bridge.loadModel(modelPath); + super.init(eltType,torchModuleHandle); + this.modelPath = modelPath; + init this; + this.moduleName = "LoadedTorchModel"; + } - proc init(type eltType, modulePath: string) do - super.init(eltType,modulePath); + proc init(modelPath: string) do + this.init(defaultEltType,modelPath); + } + + class StyleTransfer : LoadedTorchModel(?) { + proc init(type eltType, modelPath: string) { + super.init(eltType,modelPath); + init this; + this.moduleName = "StyleTransferLoadedTorchModel"; + } - proc init(modulePath: string) do - super.init(defaultEltType,modulePath); + proc init(modelPath: string) do + this.init(defaultEltType,modelPath); override proc forward(input: dynamicTensor(eltType)): dynamicTensor(eltType) { - const btInput: Bridge.tensorHandle(eltType) = input : Bridge.tensorHandle(eltType); - const btOutput = Bridge.model_forward_style_transfer(this.moduleHandle, btInput); - return btOutput : dynamicTensor(eltType); + const th = input : Bridge.tensorHandle(eltType); + const thOutput + = Bridge.modelForwardStyleTransfer(this.torchModuleHandle,th); + return thOutput : dynamicTensor(eltType); } } + + } \ No newline at end of file