@@ -211,44 +211,62 @@ module Layer {
211211 }
212212 }
213213
214- import CTypes;
214+
215215 class TorchModule : Module(?) {
216- var modulePath: string ;
217- var moduleHandle: Bridge.bridge_pt_model_t;
216+ var torchModuleHandle: Bridge.torchModuleHandle;
218217
219- proc init (type eltType, modulePath : string ) {
218+ proc init (type eltType, torchModuleHandle : Bridge.torchModuleHandle ) {
220219 super .init(eltType);
221- this .modulePath = modulePath;
222- const fpPtr: CTypes.c_ptr(uint (8 )) = CTypes.c_ptrToConst(modulePath) : CTypes.c_ptr(uint (8 ));
223- this .moduleHandle = Bridge.load_model(fpPtr);
220+ this .torchModuleHandle = torchModuleHandle;
224221 init this ;
225222 this .moduleName = " TorchModule" ;
226223 }
227224
228- proc init (modulePath : string ) do
229- this .init(defaultEltType,modulePath );
225+ proc init (torchModuleHandle : Bridge.torchModuleHandle ) do
226+ this .init(defaultEltType,torchModuleHandle );
230227
231- override proc forward(input: dynamicTensor(eltType)): dynamicTensor(eltType) {
232- const btInput: Bridge.tensorHandle(eltType) = input : Bridge.tensorHandle(eltType);
233- const btOutput = Bridge.model_forward(this .moduleHandle, btInput);
234- return btOutput : dynamicTensor(eltType);
228+ override proc forward(input: dynamicTensor(eltType)):
229+ dynamicTensor(eltType) {
230+ const th = input : Bridge.tensorHandle(eltType);
231+ const thOutput = Bridge.modelForward(this .torchModuleHandle,th);
232+ return thOutput : dynamicTensor(eltType);
235233 }
236234 }
237235
238- class StyleTransfer : TorchModule(?) {
236+ class LoadedTorchModel : TorchModule(?) {
237+ var modelPath: string ;
238+
239+ proc init (type eltType, modelPath: string ) {
240+ var torchModuleHandle = Bridge.loadModel(modelPath);
241+ super .init(eltType,torchModuleHandle);
242+ this .modelPath = modelPath;
243+ init this ;
244+ this .moduleName = " LoadedTorchModel" ;
245+ }
239246
240- proc init (type eltType, modulePath: string ) do
241- super .init(eltType,modulePath);
247+ proc init (modelPath: string ) do
248+ this .init(defaultEltType,modelPath);
249+ }
250+
251+ class StyleTransfer : LoadedTorchModel(?) {
252+ proc init (type eltType, modelPath: string ) {
253+ super .init(eltType,modelPath);
254+ init this ;
255+ this .moduleName = " StyleTransferLoadedTorchModel" ;
256+ }
242257
243- proc init (modulePath : string ) do
244- super .init(defaultEltType,modulePath );
258+ proc init (modelPath : string ) do
259+ this .init(defaultEltType,modelPath );
245260
246261 override proc forward(input: dynamicTensor(eltType)):
247262 dynamicTensor(eltType) {
248- const btInput: Bridge.tensorHandle(eltType) = input : Bridge.tensorHandle(eltType);
249- const btOutput = Bridge.model_forward_style_transfer(this .moduleHandle, btInput);
250- return btOutput : dynamicTensor(eltType);
263+ const th = input : Bridge.tensorHandle(eltType);
264+ const thOutput
265+ = Bridge.modelForwardStyleTransfer(this .torchModuleHandle,th);
266+ return thOutput : dynamicTensor(eltType);
251267 }
252268 }
253269
270+
271+
254272}
0 commit comments