diff --git a/src/gnn_tracking/models/interaction_network.py b/src/gnn_tracking/models/interaction_network.py index 192d10fc..40896430 100644 --- a/src/gnn_tracking/models/interaction_network.py +++ b/src/gnn_tracking/models/interaction_network.py @@ -1,6 +1,7 @@ import torch from pytorch_lightning.core.mixins import HyperparametersMixin from torch import Tensor as T +from torch.jit import script as jit from torch_geometric.nn import MessagePassing from gnn_tracking.models.mlp import MLP @@ -33,15 +34,19 @@ def __init__( """ super().__init__(aggr=aggr, flow="source_to_target") self.save_hyperparameters() - self.relational_model = MLP( - 2 * node_indim + edge_indim, - edge_outdim, - edge_hidden_dim, + self.relational_model = jit( + MLP( + 2 * node_indim + edge_indim, + edge_outdim, + edge_hidden_dim, + ) ) - self.object_model = MLP( - node_indim + edge_outdim, - node_outdim, - node_hidden_dim, + self.object_model = jit( + MLP( + node_indim + edge_outdim, + node_outdim, + node_hidden_dim, + ) ) self._e_tilde: T | None = None