Skip to content

Commit 8b7ee82

Browse files
Merge branch 'Iainmon:main' into main
2 parents 6bdda37 + 26bf18a commit 8b7ee82

4 files changed

Lines changed: 55 additions & 30 deletions

File tree

demos/video/chapel-webcam/readme.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ make bridge_objs
66
cd demos/video/chapel-webcam
77
make cleanall && make clean && make libsmol && make main
88
./main --modelPath sobel.pt
9+
10+
./main --chaiImpl=true --accelScale=0.45 --modelPath sobel.pt
911
```

demos/video/chapel-webcam/smol.chpl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,17 @@ const startTime = getTime();
7272
// ../style-transfer/models/exports/mps/starry_ep3_bt4_sw1e11_cw_1e5_float32.pt // This is the one
7373
// ../style-transfer/models/exports/cpu/mosaic_float16.pt
7474
config const modelPath: string = "../style-transfer/models/exports/cpu/mosaic_float16.pt";
75-
var model : Bridge.bridge_pt_model_t;
75+
var model : Bridge.torchModuleHandle;
7676

7777
var modelLayer : shared TorchModule(real(32))?;
7878

7979

8080
use CTypes;
8181

8282
export proc globalLoadModel() {
83-
const fpPtr: c_ptr(uint(8)) = c_ptrToConst(modelPath) : c_ptr(uint(8));
84-
model = Bridge.load_model(fpPtr);
83+
model = Bridge.loadModel(modelPath);
8584
if modelPath == "sobel.pt" then
86-
modelLayer = new shared TorchModule(modelPath);
85+
modelLayer = new shared LoadedTorchModel(modelPath);
8786
else
8887
modelLayer = new shared StyleTransfer(modelPath);
8988
}
@@ -173,9 +172,9 @@ export proc getNewFrame(ref frame: [] real(32),height: int, width: int,channels:
173172
var btFrame: Bridge.bridge_tensor_t = Bridge.createBridgeTensorWithShape(frame,shape);
174173
var bt: Bridge.bridge_tensor_t;
175174
if modelPath == "sobel.pt" then
176-
bt = Bridge.model_forward(model,btFrame);
175+
bt = Bridge.modelForward(model,btFrame);
177176
else
178-
bt = Bridge.model_forward_style_transfer(model,btFrame);
177+
bt = Bridge.modelForwardStyleTransfer(model,btFrame);
179178

180179
const nextNDFrame = bt : ndarray(3, real(32));
181180
const flattenedNextFrame = nextNDFrame.flatten().data;

lib/Bridge.chpl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ module Bridge {
4444
}
4545
}
4646

47+
proc torchModuleHandle type do return bridge_pt_model_t;
48+
4749
extern proc unsafe(const ref arr: [] real(32)): c_ptr(real(32));
4850

4951
// extern proc load_tensor_from_file(file_path: c_ptrConst(u_char)): bridge_tensor_t; // Working
@@ -61,13 +63,17 @@ module Bridge {
6163
model_path: string_t,
6264
in input: bridge_tensor_t): bridge_tensor_t;
6365

64-
extern proc load_model(model_path: string_t): bridge_pt_model_t;
66+
extern "load_model" proc loadModelC(model_path: string_t): bridge_pt_model_t;
67+
proc loadModel(modelPath: string): torchModuleHandle {
68+
const model_path: c_ptr(uint(8)) = c_ptrToConst(modelPath) : c_ptr(uint(8));
69+
return loadModelC(model_path);
70+
}
6571

66-
extern proc model_forward(
72+
extern "model_forward" proc modelForward(
6773
in model: bridge_pt_model_t,
6874
in input: bridge_tensor_t): bridge_tensor_t;
6975

70-
extern proc model_forward_style_transfer(
76+
extern "model_forward_style_transfer" proc modelForwardStyleTransfer(
7177
in model: bridge_pt_model_t,
7278
in input: bridge_tensor_t): bridge_tensor_t;
7379

lib/Layer.chpl

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -536,44 +536,62 @@ module Layer {
536536
}
537537
}
538538

539-
import CTypes;
539+
540540
class TorchModule : Module(?) {
541-
var modulePath: string;
542-
var moduleHandle: Bridge.bridge_pt_model_t;
541+
var torchModuleHandle: Bridge.torchModuleHandle;
543542

544-
proc init(type eltType, modulePath: string) {
543+
proc init(type eltType, torchModuleHandle: Bridge.torchModuleHandle) {
545544
super.init(eltType);
546-
this.modulePath = modulePath;
547-
const fpPtr: CTypes.c_ptr(uint(8)) = CTypes.c_ptrToConst(modulePath) : CTypes.c_ptr(uint(8));
548-
this.moduleHandle = Bridge.load_model(fpPtr);
545+
this.torchModuleHandle = torchModuleHandle;
549546
init this;
550547
this.moduleName = "TorchModule";
551548
}
552549

553-
proc init(modulePath: string) do
554-
this.init(defaultEltType,modulePath);
550+
proc init(torchModuleHandle: Bridge.torchModuleHandle) do
551+
this.init(defaultEltType,torchModuleHandle);
555552

556-
override proc forward(input: dynamicTensor(eltType)): dynamicTensor(eltType) {
557-
const btInput: Bridge.tensorHandle(eltType) = input : Bridge.tensorHandle(eltType);
558-
const btOutput = Bridge.model_forward(this.moduleHandle, btInput);
559-
return btOutput : dynamicTensor(eltType);
553+
override proc forward(input: dynamicTensor(eltType)):
554+
dynamicTensor(eltType) {
555+
const th = input : Bridge.tensorHandle(eltType);
556+
const thOutput = Bridge.modelForward(this.torchModuleHandle,th);
557+
return thOutput : dynamicTensor(eltType);
560558
}
561559
}
562560

563-
class StyleTransfer : TorchModule(?) {
561+
class LoadedTorchModel : TorchModule(?) {
562+
var modelPath: string;
563+
564+
proc init(type eltType, modelPath: string) {
565+
var torchModuleHandle = Bridge.loadModel(modelPath);
566+
super.init(eltType,torchModuleHandle);
567+
this.modelPath = modelPath;
568+
init this;
569+
this.moduleName = "LoadedTorchModel";
570+
}
564571

565-
proc init(type eltType, modulePath: string) do
566-
super.init(eltType,modulePath);
572+
proc init(modelPath: string) do
573+
this.init(defaultEltType,modelPath);
574+
}
575+
576+
class StyleTransfer : LoadedTorchModel(?) {
577+
proc init(type eltType, modelPath: string) {
578+
super.init(eltType,modelPath);
579+
init this;
580+
this.moduleName = "StyleTransferLoadedTorchModel";
581+
}
567582

568-
proc init(modulePath: string) do
569-
super.init(defaultEltType,modulePath);
583+
proc init(modelPath: string) do
584+
this.init(defaultEltType,modelPath);
570585

571586
override proc forward(input: dynamicTensor(eltType)):
572587
dynamicTensor(eltType) {
573-
const btInput: Bridge.tensorHandle(eltType) = input : Bridge.tensorHandle(eltType);
574-
const btOutput = Bridge.model_forward_style_transfer(this.moduleHandle, btInput);
575-
return btOutput : dynamicTensor(eltType);
588+
const th = input : Bridge.tensorHandle(eltType);
589+
const thOutput
590+
= Bridge.modelForwardStyleTransfer(this.torchModuleHandle,th);
591+
return thOutput : dynamicTensor(eltType);
576592
}
577593
}
578594

595+
596+
579597
}

0 commit comments

Comments
 (0)