1+ // import Utilities as utils;
2+ // import Bridge;
3+
4+ // use NDArray;
5+ // use Layer;
6+
7+ use Tensor;
8+ use Layer;
19import Utilities as utils;
210
3- use NDArray;
4- import Bridge;
511
612export proc square(x: int ): int {
713 writeln (x, " * " , x, " = " , x * x);
@@ -38,11 +44,18 @@ const startTime = getTime();
3844config const modelPath: string = " ../style-transfer/models/exports/mps/starry_ep3_bt4_sw1e11_cw_1e5_float32.pt" ;
3945var model : Bridge.bridge_pt_model_t;
4046
47+ var modelLayer : shared TorchModule(real (32 ))?;
48+
49+
4150use CTypes;
4251
4352export proc globalLoadModel() {
4453 const fpPtr: c_ptr(uint (8 )) = c_ptrToConst(modelPath) : c_ptr(uint (8 ));
4554 model = Bridge.load_model(fpPtr);
55+ if modelPath == " sobel.pt" then
56+ modelLayer = new shared TorchModule(modelPath);
57+ else
58+ modelLayer = new shared StyleTransfer(modelPath);
4659
4760 // const fpPtr: c_ptr(uint(8)) = c_ptrToConst(modelPath) : c_ptr(uint(8));
4861 // var model = Bridge.load_model(fpPtr);
@@ -58,7 +71,16 @@ export proc getNewFrame(ref frame: [] real(32),height: int, width: int,channels:
5871 writeln (" FPS: " , 1.0 / dt);
5972 const shape = (height,width,channels);
6073 const frameDom = utils.domainFromShape((...shape));
74+ // const frameArr = reshape(frame,frameDom);
75+ const dtInput = (new dynamicTensor(frame)).reshape((...shape));
76+ const dtOutput = modelLayer!.forward(dtInput);
77+ const outputFrame = dtOutput.flatten().toArray(1 );
78+
6179
80+ lastFrame = getTime();
81+ return outputFrame;
82+
83+ /*
6284 var btFrame: Bridge.bridge_tensor_t = Bridge.createBridgeTensorWithShape(frame,shape);
6385 var bt: Bridge.bridge_tensor_t;
6486 if modelPath == "sobel.pt" then
@@ -72,6 +94,8 @@ export proc getNewFrame(ref frame: [] real(32),height: int, width: int,channels:
7294 const flattenedNextFrame = nextNDFrame.flatten().data;
7395 lastFrame = getTime();
7496 return flattenedNextFrame;
97+ */
98+
7599
76100 // forall i in 0..<frame.size {
77101 // const idx = utils.indexAt(i,(...shape));
0 commit comments