Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions demos/video/chapel-webcam/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ make bridge_objs
cd demos/video/chapel-webcam
make cleanall && make clean && make libsmol && make main
./main --modelPath sobel.pt

./main --chaiImpl=true --accelScale=0.45 --modelPath sobel.pt
```
11 changes: 5 additions & 6 deletions demos/video/chapel-webcam/smol.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,17 @@ const startTime = getTime();
// ../style-transfer/models/exports/mps/starry_ep3_bt4_sw1e11_cw_1e5_float32.pt // This is the one
// ../style-transfer/models/exports/cpu/mosaic_float16.pt
config const modelPath: string = "../style-transfer/models/exports/cpu/mosaic_float16.pt";
var model : Bridge.bridge_pt_model_t;
var model : Bridge.torchModuleHandle;

var modelLayer : shared TorchModule(real(32))?;


use CTypes;

export proc globalLoadModel() {
const fpPtr: c_ptr(uint(8)) = c_ptrToConst(modelPath) : c_ptr(uint(8));
model = Bridge.load_model(fpPtr);
model = Bridge.loadModel(modelPath);
if modelPath == "sobel.pt" then
modelLayer = new shared TorchModule(modelPath);
modelLayer = new shared LoadedTorchModel(modelPath);
else
modelLayer = new shared StyleTransfer(modelPath);
}
Expand Down Expand Up @@ -173,9 +172,9 @@ export proc getNewFrame(ref frame: [] real(32),height: int, width: int,channels:
var btFrame: Bridge.bridge_tensor_t = Bridge.createBridgeTensorWithShape(frame,shape);
var bt: Bridge.bridge_tensor_t;
if modelPath == "sobel.pt" then
bt = Bridge.model_forward(model,btFrame);
bt = Bridge.modelForward(model,btFrame);
else
bt = Bridge.model_forward_style_transfer(model,btFrame);
bt = Bridge.modelForwardStyleTransfer(model,btFrame);

const nextNDFrame = bt : ndarray(3, real(32));
const flattenedNextFrame = nextNDFrame.flatten().data;
Expand Down
12 changes: 9 additions & 3 deletions lib/Bridge.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ module Bridge {
}
}

proc torchModuleHandle type do return bridge_pt_model_t;

extern proc unsafe(const ref arr: [] real(32)): c_ptr(real(32));

// extern proc load_tensor_from_file(file_path: c_ptrConst(u_char)): bridge_tensor_t; // Working
Expand All @@ -61,13 +63,17 @@ module Bridge {
model_path: string_t,
in input: bridge_tensor_t): bridge_tensor_t;

extern proc load_model(model_path: string_t): bridge_pt_model_t;
extern "load_model" proc loadModelC(model_path: string_t): bridge_pt_model_t;
proc loadModel(modelPath: string): torchModuleHandle {
const model_path: c_ptr(uint(8)) = c_ptrToConst(modelPath) : c_ptr(uint(8));
return loadModelC(model_path);
}

extern proc model_forward(
extern "model_forward" proc modelForward(
in model: bridge_pt_model_t,
in input: bridge_tensor_t): bridge_tensor_t;

extern proc model_forward_style_transfer(
extern "model_forward_style_transfer" proc modelForwardStyleTransfer(
in model: bridge_pt_model_t,
in input: bridge_tensor_t): bridge_tensor_t;

Expand Down
60 changes: 39 additions & 21 deletions lib/Layer.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -211,44 +211,62 @@ module Layer {
}
}

import CTypes;

class TorchModule : Module(?) {
var modulePath: string;
var moduleHandle: Bridge.bridge_pt_model_t;
var torchModuleHandle: Bridge.torchModuleHandle;

proc init(type eltType, modulePath: string) {
proc init(type eltType, torchModuleHandle: Bridge.torchModuleHandle) {
super.init(eltType);
this.modulePath = modulePath;
const fpPtr: CTypes.c_ptr(uint(8)) = CTypes.c_ptrToConst(modulePath) : CTypes.c_ptr(uint(8));
this.moduleHandle = Bridge.load_model(fpPtr);
this.torchModuleHandle = torchModuleHandle;
init this;
this.moduleName = "TorchModule";
}

proc init(modulePath: string) do
this.init(defaultEltType,modulePath);
proc init(torchModuleHandle: Bridge.torchModuleHandle) do
this.init(defaultEltType,torchModuleHandle);

override proc forward(input: dynamicTensor(eltType)): dynamicTensor(eltType) {
const btInput: Bridge.tensorHandle(eltType) = input : Bridge.tensorHandle(eltType);
const btOutput = Bridge.model_forward(this.moduleHandle, btInput);
return btOutput : dynamicTensor(eltType);
override proc forward(input: dynamicTensor(eltType)):
dynamicTensor(eltType) {
const th = input : Bridge.tensorHandle(eltType);
const thOutput = Bridge.modelForward(this.torchModuleHandle,th);
return thOutput : dynamicTensor(eltType);
}
}

class StyleTransfer : TorchModule(?) {
class LoadedTorchModel : TorchModule(?) {
var modelPath: string;

proc init(type eltType, modelPath: string) {
var torchModuleHandle = Bridge.loadModel(modelPath);
super.init(eltType,torchModuleHandle);
this.modelPath = modelPath;
init this;
this.moduleName = "LoadedTorchModel";
}

proc init(type eltType, modulePath: string) do
super.init(eltType,modulePath);
proc init(modelPath: string) do
this.init(defaultEltType,modelPath);
}

class StyleTransfer : LoadedTorchModel(?) {
proc init(type eltType, modelPath: string) {
super.init(eltType,modelPath);
init this;
this.moduleName = "StyleTransferLoadedTorchModel";
}

proc init(modulePath: string) do
super.init(defaultEltType,modulePath);
proc init(modelPath: string) do
this.init(defaultEltType,modelPath);

override proc forward(input: dynamicTensor(eltType)):
dynamicTensor(eltType) {
const btInput: Bridge.tensorHandle(eltType) = input : Bridge.tensorHandle(eltType);
const btOutput = Bridge.model_forward_style_transfer(this.moduleHandle, btInput);
return btOutput : dynamicTensor(eltType);
const th = input : Bridge.tensorHandle(eltType);
const thOutput
= Bridge.modelForwardStyleTransfer(this.torchModuleHandle,th);
return thOutput : dynamicTensor(eltType);
}
}



}
Loading