Skip to content

Commit 2c4df75

Browse files
author
Robin Manhaeve
committed
Fix batching network inputs of mixed type
1 parent 04fe7a1 commit 2c4df75

File tree

2 files changed

+40
-28
lines changed

2 files changed

+40
-28
lines changed

src/deepproblog/network.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ class Network(object):
1717
"""Wraps a PyTorch neural network for use with DeepProblog"""
1818

1919
def __init__(
20-
self,
21-
network_module: torch.nn.Module,
22-
name: str,
23-
optimizer: Optional[torch.optim.Optimizer] = None,
24-
scheduler=None,
25-
k: Optional[int] = None,
26-
batching: bool = False,
20+
self,
21+
network_module: torch.nn.Module,
22+
name: str,
23+
optimizer: Optional[torch.optim.Optimizer] = None,
24+
scheduler=None,
25+
k: Optional[int] = None,
26+
batching: bool = False,
2727
):
2828
"""Create a Network object
2929
@@ -121,13 +121,17 @@ def __call__(self, to_evaluate: list) -> list:
121121
:return:
122122
"""
123123
if self.batching:
124-
batched_inputs: List[torch.Tensor] = [
125-
self.function(*e)[0] for e in to_evaluate
126-
]
127-
stacked_inputs = torch.stack(batched_inputs)
128-
if self.is_cuda:
129-
stacked_inputs = stacked_inputs.cuda(device=self.device)
130-
evaluated = self.network_module(stacked_inputs)
124+
inputs = (self.function(*e) for e in to_evaluate)
125+
stacked_inputs = list()
126+
for inputs in zip(*inputs):
127+
try:
128+
inputs = torch.stack(inputs)
129+
if self.is_cuda:
130+
inputs.cuda(device=self.device)
131+
except TypeError:
132+
inputs = list(inputs)
133+
stacked_inputs.append(inputs)
134+
evaluated = self.network_module(*stacked_inputs)
131135
else:
132136
evaluated = [self.network_module(*self.function(*e)) for e in to_evaluate]
133137
return evaluated
@@ -169,7 +173,6 @@ def get_hyperparameters(self):
169173
}
170174
return parameters
171175

172-
173176
# class NetworkEvaluation(object):
174177
# """
175178
# An object that keeps track of which inputs the neural networks need to be evaluated on.

src/deepproblog/tests/test_neural_predicate.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
nn(dummy1,[X],Y,[a,b,c]) :: net1(X,Y).
1313
nn(dummy2,[X]) :: net2(X).
1414
nn(dummy3,[X],Y) :: net3(X,Y).
15-
nn(dummy4,[X,Y],Z,[a,b]) :: net4(X,Y,Z).
15+
nn(dummy4,[X,Y]) :: net4(X,Y).
1616
1717
test1(X1,Y1,X2,Y2) :- net1(X1,Y1), net1(X2,Y2).
1818
test2(X1,X2) :- net2(X1), net2(X2).
@@ -28,9 +28,19 @@
2828
dummy_values3 = {Term("i1"): [1.0, 2.0, 3.0, 4.0], Term("i2"): [-1.0, 0.0, 1.0]}
2929
dummy_net3 = Network(DummyNet(dummy_values3), "dummy3")
3030

31-
dummy_net4 = Network(DummyTensorNet(batching=True), "dummy4", batching=True)
3231

33-
tensors = {(Constant(0),): torch.Tensor([0.2]), (Constant(1),): torch.Tensor([0.8])}
32+
dummy_tensors = {(Term("a"),): torch.Tensor([0.1, 0.2, 0.3, 0.4]), (Term("b"),): torch.Tensor([0.25, 0.25, 0.25, 0.25])}
33+
34+
35+
class IndexNet(torch.nn.Module):
36+
37+
def forward(self, t, index):
38+
# index = int(index)
39+
index = torch.LongTensor([int(i) for i in index])
40+
return t.index_select(dim=1, index=index)
41+
42+
43+
dummy_net4 = Network(IndexNet(), "dummy4", batching=True)
3444

3545

3646
@pytest.fixture(
@@ -53,7 +63,7 @@ def model(request) -> Model:
5363
model = Model(program, [dummy_net1, dummy_net2, dummy_net3, dummy_net4], load=False)
5464
engine = request.param["engine_factory"](model)
5565
model.set_engine(engine, cache=request.param["cache"])
56-
model.add_tensor_source('dummy', tensors)
66+
model.add_tensor_source('dummy', dummy_tensors)
5767
return model
5868

5969

@@ -108,13 +118,12 @@ def test_det_network_substitution(model: Model):
108118
assert all(r1.detach().numpy() == [1.0, 2.0, 3.0, 4.0])
109119
assert all(r2.detach().numpy() == [-1.0, 0.0, 1.0])
110120

111-
def test_double_input(model: Model):
112-
terms = lambda x: Term("net4",
113-
Term("tensor",Term("dummy", Constant(0))),
114-
Term("tensor",Term("dummy", Constant(1))),
115-
x)
116-
results = model.solve([Query(terms(Var("X")))])
117-
r1 = float(results[0].result[terms(Term("a"))])
118-
r2 = float(results[0].result[terms(Term("b"))])
121+
def test_multi_input_network(model: Model):
122+
dummy_tensor = lambda x: Term("tensor", Term("dummy", x))
123+
q1 = Query(Term("net4", dummy_tensor(Term("a")), Constant(1)))
124+
q2 = Query(Term("net4", dummy_tensor(Term("b")), Constant(2)))
125+
results = model.solve([q1, q2])
126+
r1 = float(results[0].result[q1.query])
127+
r2 = float(results[1].result[q2.query])
119128
assert pytest.approx(0.2) == r1
120-
assert pytest.approx(0.8) == r2
129+
assert pytest.approx(0.25) == r2

0 commit comments

Comments
 (0)