Skip to content

Commit c3f1c59

Browse files
committed
batch size parameter explicilty added
1 parent ff99627 commit c3f1c59

2 files changed

Lines changed: 4 additions & 3 deletions

File tree

owlapy/embedding_based_reasoner.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,12 @@ class EBR(AbstractOWLReasoner):
3434
STR_IRI_DATA_PROPERTY = "http://www.w3.org/2002/07/owl#DatatypeProperty"
3535
STR_IRI_SUBPROPERTY = "http://www.w3.org/2000/01/rdf-schema#subPropertyOf"
3636

37-
def __init__(self, ontology: NeuralOntology, gamma: float = 0.5, device: str = "gpu"):
37+
def __init__(self, ontology: NeuralOntology, gamma: float = 0.5, batch_size: int = 1024, device: str = "gpu"):
3838
super().__init__(ontology)
3939
self.gamma = gamma
4040
self.ontology = ontology
4141
self.model = ontology.model
42+
self.batch_size = batch_size
4243
if device == "gpu" and torch.cuda.is_available():
4344
self.model.to("cuda")
4445
print("EBR inference on GPU")
@@ -66,7 +67,7 @@ def predict(self, h: List[str] = None, r: List[str] = None, t: List[str] = None)
6667
else:
6768
topk = len(self.model.entity_to_idx)
6869

69-
return [ (top_entity, score) for row in self.model.predict_topk(h=h, r=r, t=t, topk=topk) for top_entity, score in row if score >= self.gamma and is_valid_entity(top_entity)]
70+
return [ (top_entity, score) for row in self.model.predict_topk(h=h, r=r, t=t, topk=topk, batch_size=self.batch_size) for top_entity, score in row if score >= self.gamma and is_valid_entity(top_entity)]
7071

7172
def predict_individuals_of_owl_class(self, owl_class: OWLClass) -> List[OWLNamedIndividual]:
7273
top_entities=set()

tests/test_embedding_based_reasoner_retrieval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def setup_class(cls):
5252
path_neural_embedding=cls.path_kg,
5353
train_if_not_exists=True
5454
)
55-
cls.neural_owl_reasoner = EBR(ontology=neural_ontology, gamma=cls.gamma)
55+
cls.neural_owl_reasoner = EBR(ontology=neural_ontology, gamma=cls.gamma, batch_size=2, device="cpu")
5656

5757
# Generate test concepts
5858
cls._generate_test_concepts()

0 commit comments

Comments
 (0)