@@ -536,44 +536,62 @@ module Layer {
536536 }
537537 }
538538
539- import CTypes;
539+
540540 class TorchModule : Module(?) {
541- var modulePath: string ;
542- var moduleHandle: Bridge.bridge_pt_model_t;
541+ var torchModuleHandle: Bridge.torchModuleHandle;
543542
544- proc init (type eltType, modulePath : string ) {
543+ proc init (type eltType, torchModuleHandle : Bridge.torchModuleHandle ) {
545544 super .init(eltType);
546- this .modulePath = modulePath;
547- const fpPtr: CTypes.c_ptr(uint (8 )) = CTypes.c_ptrToConst(modulePath) : CTypes.c_ptr(uint (8 ));
548- this .moduleHandle = Bridge.load_model(fpPtr);
545+ this .torchModuleHandle = torchModuleHandle;
549546 init this ;
550547 this .moduleName = " TorchModule" ;
551548 }
552549
553- proc init (modulePath : string ) do
554- this .init(defaultEltType,modulePath );
550+ proc init (torchModuleHandle : Bridge.torchModuleHandle ) do
551+ this .init(defaultEltType,torchModuleHandle );
555552
556- override proc forward(input: dynamicTensor(eltType)): dynamicTensor(eltType) {
557- const btInput: Bridge.tensorHandle(eltType) = input : Bridge.tensorHandle(eltType);
558- const btOutput = Bridge.model_forward(this .moduleHandle, btInput);
559- return btOutput : dynamicTensor(eltType);
553+ override proc forward(input: dynamicTensor(eltType)):
554+ dynamicTensor(eltType) {
555+ const th = input : Bridge.tensorHandle(eltType);
556+ const thOutput = Bridge.modelForward(this .torchModuleHandle,th);
557+ return thOutput : dynamicTensor(eltType);
560558 }
561559 }
562560
563- class StyleTransfer : TorchModule(?) {
561+ class LoadedTorchModel : TorchModule(?) {
562+ var modelPath: string ;
563+
564+ proc init (type eltType, modelPath: string ) {
565+ var torchModuleHandle = Bridge.loadModel(modelPath);
566+ super .init(eltType,torchModuleHandle);
567+ this .modelPath = modelPath;
568+ init this ;
569+ this .moduleName = " LoadedTorchModel" ;
570+ }
564571
565- proc init (type eltType, modulePath: string ) do
566- super .init(eltType,modulePath);
572+ proc init (modelPath: string ) do
573+ this .init(defaultEltType,modelPath);
574+ }
575+
576+ class StyleTransfer : LoadedTorchModel(?) {
577+ proc init (type eltType, modelPath: string ) {
578+ super .init(eltType,modelPath);
579+ init this ;
580+ this .moduleName = " StyleTransferLoadedTorchModel" ;
581+ }
567582
568- proc init (modulePath : string ) do
569- super .init(defaultEltType,modulePath );
583+ proc init (modelPath : string ) do
584+ this .init(defaultEltType,modelPath );
570585
571586 override proc forward(input: dynamicTensor(eltType)):
572587 dynamicTensor(eltType) {
573- const btInput: Bridge.tensorHandle(eltType) = input : Bridge.tensorHandle(eltType);
574- const btOutput = Bridge.model_forward_style_transfer(this .moduleHandle, btInput);
575- return btOutput : dynamicTensor(eltType);
588+ const th = input : Bridge.tensorHandle(eltType);
589+ const thOutput
590+ = Bridge.modelForwardStyleTransfer(this .torchModuleHandle,th);
591+ return thOutput : dynamicTensor(eltType);
576592 }
577593 }
578594
595+
596+
579597}
0 commit comments