1212nn(dummy1,[X],Y,[a,b,c]) :: net1(X,Y).
1313nn(dummy2,[X]) :: net2(X).
1414nn(dummy3,[X],Y) :: net3(X,Y).
15- nn(dummy4,[X,Y],Z,[a,b] ) :: net4(X,Y,Z ).
15+ nn(dummy4,[X,Y]) :: net4(X,Y).
1616
1717test1(X1,Y1,X2,Y2) :- net1(X1,Y1), net1(X2,Y2).
1818test2(X1,X2) :- net2(X1), net2(X2).
2828dummy_values3 = {Term ("i1" ): [1.0 , 2.0 , 3.0 , 4.0 ], Term ("i2" ): [- 1.0 , 0.0 , 1.0 ]}
2929dummy_net3 = Network (DummyNet (dummy_values3 ), "dummy3" )
3030
31- dummy_net4 = Network (DummyTensorNet (batching = True ), "dummy4" , batching = True )
3231
33- tensors = {(Constant (0 ),): torch .Tensor ([0.2 ]), (Constant (1 ),): torch .Tensor ([0.8 ])}
32+ dummy_tensors = {(Term ("a" ),): torch .Tensor ([0.1 , 0.2 , 0.3 , 0.4 ]), (Term ("b" ),): torch .Tensor ([0.25 , 0.25 , 0.25 , 0.25 ])}
33+
34+
35+ class IndexNet (torch .nn .Module ):
36+
37+ def forward (self , t , index ):
38+ # index = int(index)
39+ index = torch .LongTensor ([int (i ) for i in index ])
40+ return t .index_select (dim = 1 , index = index )
41+
42+
43+ dummy_net4 = Network (IndexNet (), "dummy4" , batching = True )
3444
3545
3646@pytest .fixture (
@@ -53,7 +63,7 @@ def model(request) -> Model:
5363 model = Model (program , [dummy_net1 , dummy_net2 , dummy_net3 , dummy_net4 ], load = False )
5464 engine = request .param ["engine_factory" ](model )
5565 model .set_engine (engine , cache = request .param ["cache" ])
56- model .add_tensor_source ('dummy' , tensors )
66+ model .add_tensor_source ('dummy' , dummy_tensors )
5767 return model
5868
5969
@@ -108,13 +118,12 @@ def test_det_network_substitution(model: Model):
108118 assert all (r1 .detach ().numpy () == [1.0 , 2.0 , 3.0 , 4.0 ])
109119 assert all (r2 .detach ().numpy () == [- 1.0 , 0.0 , 1.0 ])
110120
111- def test_double_input (model : Model ):
112- terms = lambda x : Term ("net4" ,
113- Term ("tensor" ,Term ("dummy" , Constant (0 ))),
114- Term ("tensor" ,Term ("dummy" , Constant (1 ))),
115- x )
116- results = model .solve ([Query (terms (Var ("X" )))])
117- r1 = float (results [0 ].result [terms (Term ("a" ))])
118- r2 = float (results [0 ].result [terms (Term ("b" ))])
121+ def test_multi_input_network (model : Model ):
122+ dummy_tensor = lambda x : Term ("tensor" , Term ("dummy" , x ))
123+ q1 = Query (Term ("net4" , dummy_tensor (Term ("a" )), Constant (1 )))
124+ q2 = Query (Term ("net4" , dummy_tensor (Term ("b" )), Constant (2 )))
125+ results = model .solve ([q1 , q2 ])
126+ r1 = float (results [0 ].result [q1 .query ])
127+ r2 = float (results [1 ].result [q2 .query ])
119128 assert pytest .approx (0.2 ) == r1
120- assert pytest .approx (0.8 ) == r2
129+ assert pytest .approx (0.25 ) == r2
0 commit comments