diff --git a/neuralTensorNetwork.py b/neuralTensorNetwork.py index d0ac2e4..2a611fd 100644 --- a/neuralTensorNetwork.py +++ b/neuralTensorNetwork.py @@ -216,9 +216,9 @@ def neuralTensorNetworkCost(self, theta, data_batch, flip): """ Get entity vectors for examples of 'i'th relation """ - entity_vectors_e1 = entity_vectors[:, e1.tolist()] - entity_vectors_e2 = entity_vectors[:, e2.tolist()] - entity_vectors_e3 = entity_vectors[:, e3.tolist()] + entity_vectors_e1 = entity_vectors[:, [int(i) for i in e1.tolist()]] + entity_vectors_e2 = entity_vectors[:, [int(i) for i in e2.tolist()]] + entity_vectors_e3 = entity_vectors[:, [int(i) for i in e3.tolist()]] """ Choose entity vectors and lists based on 'flip' """