Skip to content

Commit 9da34e6

Browse files
committed
using torch script
1 parent 13874d9 commit 9da34e6

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

train.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,17 +381,18 @@ def one_epoch(
381381
return last_loss
382382

383383

384-
def get_model(device: Device, state: Optional[StateDict] = None):
384+
def get_model(device: Device, state: Optional[StateDict] = None) -> torch.jit.ScriptModule:
385385
"""
386-
Prepare model.
386+
Prepare script model (JIT).
387387
It creates a model, load into a given device and load weights if a state
388388
was provided.
389389
"""
390390
model = Model().to(device)
391391
if state:
392392
model.load_state_dict(state)
393393

394-
return model
394+
script_model = torch.jit.script(model)
395+
return script_model
395396

396397

397398
def train(

0 commit comments

Comments
 (0)