Skip to content

Commit 3e3e473

Browse files
author
valhassan
committed
Enhance ScriptModel to support auxiliary head and output extraction from dictionary format
1 parent 36028d9 commit 3e3e473

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

geo_deep_learning/tools/script_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33
from torch.nn import functional as F
44

55
class ScriptModel(nn.Module):
6-
def __init__(self, model, data_module, num_classes, from_logits = True):
6+
def __init__(self, model, data_module, num_classes, from_logits = True, aux_head = False):
77
super().__init__()
88
self.model = model
99
self.data_module = data_module
1010
self.num_classes = num_classes
1111
self.from_logits = from_logits
12-
12+
self.aux_head = aux_head
1313
def forward(self, x):
1414
sample = {"image": x}
1515
preprocessed = self.data_module._model_script_preprocess(sample)
1616
output = self.model(preprocessed["image"])
17+
if isinstance(output, dict):
18+
output = output['out']
1719
if self.from_logits:
1820
if self.num_classes > 1:
1921
output = F.softmax(output, dim=1)

0 commit comments

Comments
 (0)