Skip to content

Commit 7ec2d7f

Browse files
committed
created ebr test
1 parent 6fcabee commit 7ec2d7f

1 file changed

Lines changed: 285 additions & 0 deletions

File tree

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
"""Test script for Embedding Based Reasoner (EBR) retrieval evaluation.
2+
3+
This test is based on the retrieval_eval.py script and evaluates the performance
4+
of the EBR against the symbolic reasoner on various OWL class expressions.
5+
It trains a model on KGs/Family/father.owl and asserts perfect scores.
6+
"""
7+
8+
import pytest
9+
import os
10+
import random
11+
import itertools
12+
import time
13+
from typing import Tuple, Set
14+
from itertools import chain
15+
16+
from owlapy.owl_reasoner import StructuralReasoner
17+
from owlapy.embedding_based_reasoner import EBR
18+
from owlapy.owl_ontology import Ontology
19+
from owlapy.neural_ontology import NeuralOntology
20+
from owlapy.class_expression import (
21+
OWLObjectUnionOf,
22+
OWLObjectIntersectionOf,
23+
OWLObjectSomeValuesFrom,
24+
OWLObjectAllValuesFrom,
25+
OWLObjectMinCardinality,
26+
OWLObjectMaxCardinality,
27+
OWLObjectOneOf,
28+
)
29+
30+
31+
class TestEmbeddingBasedReasonerRetrieval:
32+
"""Test class for EBR retrieval performance evaluation."""
33+
34+
@classmethod
def setup_class(cls):
    """Build the fixtures shared by every test in this class.

    Stores the configuration values on the class, seeds the global random
    number generator, creates the symbolic and the embedding based reasoner
    for the same ontology file, and finally generates the class expressions
    that the individual tests retrieve.
    """
    # Configuration shared by all tests.
    cls.path_kg = "KGs/Family/father.owl"
    cls.gamma = 0.9
    cls.seed = 42
    cls.num_nominals = 10

    # Seed the global RNG up front; concept generation below samples from it.
    random.seed(cls.seed)

    # Symbolic side: a structural reasoner over the ontology file.
    assert os.path.isfile(cls.path_kg), f"Ontology file not found: {cls.path_kg}"
    cls.symbolic_kb = StructuralReasoner(Ontology(ontology_iri=cls.path_kg))

    # Neural side: the embedding based reasoner.
    # NOTE(review): train_if_not_exists=True presumably trains the embedding
    # model when none is stored yet - confirm against NeuralOntology.
    cls.neural_owl_reasoner = EBR(
        ontology=NeuralOntology(
            path_neural_embedding=cls.path_kg,
            train_if_not_exists=True,
        ),
        gamma=cls.gamma,
    )

    # Class expressions are derived from the symbolic reasoner's ontology.
    cls._generate_test_concepts()
59+
60+
@classmethod
def _generate_test_concepts(cls):
    """Populate class attributes with the OWL class expressions under test.

    Reads the signature of the symbolic reasoner's root ontology and stores
    the following groups on the class:

    * ``object_properties_and_inverse``: object properties plus their inverses
    * ``nc`` / ``nnc`` / ``nc_star``: named concepts, their complements, both
    * ``nominal_combinations``: ``OWLObjectOneOf`` over groups of up to three
      individuals
    * ``unions_nc`` / ``intersections_nc`` / ``unions_nc_star`` /
      ``intersections_nc_star``: pairwise unions and intersections
    * ``exist_nc_star`` / ``for_all_nc_star``: existential and universal
      restrictions
    * ``min_cardinality_nc_star`` / ``max_cardinality_nc_star``: cardinality
      restrictions with cardinality 1
    * ``exist_nominals``: existential restrictions whose filler is a nominal
      combination
    """
    # The root ontology is queried several times below; look it up once.
    root_ontology = cls.symbolic_kb.get_root_ontology()

    # R*: object properties united with their inverses.
    object_properties = set(root_ontology.object_properties_in_signature())
    object_properties_inverse = {prop.get_inverse_property() for prop in object_properties}
    cls.object_properties_and_inverse = object_properties.union(object_properties_inverse)

    # NC, NC- (complements) and NC* (both).
    cls.nc = set(root_ontology.classes_in_signature())
    cls.nnc = {concept.get_object_complement_of() for concept in cls.nc}
    cls.nc_star = cls.nc.union(cls.nnc)

    # Order the individuals before sampling: seeding the RNG only makes the
    # sample reproducible when the population is presented in a stable order.
    individuals = sorted(set(root_ontology.individuals_in_signature()), key=str)
    if len(individuals) > cls.num_nominals:
        nominals = random.sample(individuals, cls.num_nominals)
    else:
        nominals = individuals

    # Nominal combinations (3-tuples, or smaller when fewer individuals exist).
    cls.nominal_combinations = {
        OWLObjectOneOf(combination)
        for combination in itertools.combinations(nominals, min(3, len(nominals)))
    }

    # Pairwise unions and intersections.
    cls.unions_nc = cls._concept_reducer(cls.nc, OWLObjectUnionOf)
    cls.intersections_nc = cls._concept_reducer(cls.nc, OWLObjectIntersectionOf)
    cls.unions_nc_star = cls._concept_reducer(cls.nc_star, OWLObjectUnionOf)
    cls.intersections_nc_star = cls._concept_reducer(cls.nc_star, OWLObjectIntersectionOf)

    # Existential and universal restrictions.
    cls.exist_nc_star = cls._concept_reducer_properties(
        cls.nc_star, cls.object_properties_and_inverse, OWLObjectSomeValuesFrom
    )
    cls.for_all_nc_star = cls._concept_reducer_properties(
        cls.nc_star, cls.object_properties_and_inverse, OWLObjectAllValuesFrom
    )

    # Cardinality restrictions.
    cls.min_cardinality_nc_star = cls._concept_reducer_properties(
        cls.nc_star, cls.object_properties_and_inverse, OWLObjectMinCardinality, cardinality=1
    )
    cls.max_cardinality_nc_star = cls._concept_reducer_properties(
        cls.nc_star, cls.object_properties_and_inverse, OWLObjectMaxCardinality, cardinality=1
    )

    # Existential restrictions with nominals as fillers.
    cls.exist_nominals = cls._concept_reducer_properties(
        cls.nominal_combinations, cls.object_properties_and_inverse, OWLObjectSomeValuesFrom
    )
120+
121+
@staticmethod
122+
def _concept_reducer(concepts, operator_class):
123+
"""Create all binary combinations of concepts with the given operator."""
124+
return {
125+
operator_class(operands=frozenset([c1, c2]))
126+
for c1 in concepts for c2 in concepts if c1 != c2
127+
}
128+
129+
@staticmethod
130+
def _concept_reducer_properties(concepts, properties, restriction_class, cardinality=None):
131+
"""Create combinations of concepts with properties using the given restriction class."""
132+
if cardinality is not None:
133+
return {
134+
restriction_class(filler=c, property=p, cardinality=cardinality)
135+
for c in concepts for p in properties
136+
}
137+
else:
138+
return {
139+
restriction_class(filler=c, property=p)
140+
for c in concepts for p in properties
141+
}
142+
143+
@staticmethod
144+
def _jaccard_similarity(set1: Set, set2: Set) -> float:
145+
"""Calculate Jaccard similarity between two sets."""
146+
if len(set1) == 0 and len(set2) == 0:
147+
return 1.0
148+
intersection = len(set1.intersection(set2))
149+
union = len(set1.union(set2))
150+
return intersection / union if union > 0 else 0.0
151+
152+
@staticmethod
153+
def _f1_set_similarity(set1: Set, set2: Set) -> float:
154+
"""Calculate F1 score between two sets."""
155+
if len(set1) == 0 and len(set2) == 0:
156+
return 1.0
157+
158+
if len(set2) == 0:
159+
return 0.0
160+
161+
true_positives = len(set1.intersection(set2))
162+
precision = true_positives / len(set2) if len(set2) > 0 else 0
163+
recall = true_positives / len(set1) if len(set1) > 0 else 0
164+
165+
if precision + recall == 0:
166+
return 0.0
167+
168+
return 2 * (precision * recall) / (precision + recall)
169+
170+
def _concept_retrieval(self, retriever, concept) -> Tuple[Set[str], float]:
171+
"""Perform concept retrieval and measure runtime."""
172+
start_time = time.time()
173+
instances = {individual.str for individual in retriever.instances(concept)}
174+
runtime = time.time() - start_time
175+
return instances, runtime
176+
177+
def _test_concept_group(self, concepts, group_name):
178+
"""Test a group of concepts and assert perfect scores."""
179+
if not concepts:
180+
pytest.skip(f"No {group_name} concepts to test")
181+
182+
# Sample a few concepts for testing (to keep test runtime reasonable)
183+
test_concepts = list(concepts)[:3] if len(concepts) > 3 else list(concepts)
184+
185+
for concept in test_concepts:
186+
# Get symbolic retrieval results
187+
symbolic_results, symbolic_time = self._concept_retrieval(self.symbolic_kb, concept)
188+
189+
# Get neural retrieval results
190+
neural_results, neural_time = self._concept_retrieval(self.neural_owl_reasoner, concept)
191+
192+
# Calculate similarities
193+
jaccard_sim = self._jaccard_similarity(symbolic_results, neural_results)
194+
f1_sim = self._f1_set_similarity(symbolic_results, neural_results)
195+
196+
# Assert perfect scores
197+
assert jaccard_sim == 1.0, (
198+
f"Jaccard similarity for {group_name} concept {concept} is {jaccard_sim}, "
199+
f"expected 1.0. Symbolic: {symbolic_results}, Neural: {neural_results}"
200+
)
201+
assert f1_sim == 1.0, (
202+
f"F1 score for {group_name} concept {concept} is {f1_sim}, "
203+
f"expected 1.0. Symbolic: {symbolic_results}, Neural: {neural_results}"
204+
)
205+
206+
def test_named_concepts(self):
207+
"""Test retrieval performance on named concepts."""
208+
self._test_concept_group(self.nc, "named concepts")
209+
210+
def test_negated_named_concepts(self):
211+
"""Test retrieval performance on negated named concepts."""
212+
self._test_concept_group(self.nnc, "negated named concepts")
213+
214+
def test_union_concepts(self):
215+
"""Test retrieval performance on union concepts."""
216+
self._test_concept_group(self.unions_nc_star, "union concepts")
217+
218+
def test_intersection_concepts(self):
219+
"""Test retrieval performance on intersection concepts."""
220+
self._test_concept_group(self.intersections_nc_star, "intersection concepts")
221+
222+
def test_existential_restrictions(self):
223+
"""Test retrieval performance on existential restrictions."""
224+
self._test_concept_group(self.exist_nc_star, "existential restrictions")
225+
226+
def test_universal_restrictions(self):
227+
"""Test retrieval performance on universal restrictions."""
228+
self._test_concept_group(self.for_all_nc_star, "universal restrictions")
229+
230+
def test_min_cardinality_restrictions(self):
231+
"""Test retrieval performance on minimum cardinality restrictions."""
232+
self._test_concept_group(self.min_cardinality_nc_star, "minimum cardinality restrictions")
233+
234+
def test_max_cardinality_restrictions(self):
235+
"""Test retrieval performance on maximum cardinality restrictions."""
236+
self._test_concept_group(self.max_cardinality_nc_star, "maximum cardinality restrictions")
237+
238+
def test_existential_with_nominals(self):
239+
"""Test retrieval performance on existential restrictions with nominals."""
240+
self._test_concept_group(self.exist_nominals, "existential restrictions with nominals")
241+
242+
def test_overall_performance(self):
243+
"""Test overall performance across all concept types."""
244+
# Collect all concepts
245+
all_concepts = list(chain(
246+
list(self.nc)[:2], # Sample 2 from each group
247+
list(self.nnc)[:2],
248+
list(self.unions_nc_star)[:2],
249+
list(self.intersections_nc_star)[:2],
250+
list(self.exist_nc_star)[:2],
251+
list(self.for_all_nc_star)[:2],
252+
list(self.min_cardinality_nc_star)[:2],
253+
list(self.max_cardinality_nc_star)[:2],
254+
list(self.exist_nominals)[:2]
255+
))
256+
257+
total_jaccard = 0.0
258+
total_f1 = 0.0
259+
count = 0
260+
261+
for concept in all_concepts:
262+
# Get retrieval results
263+
symbolic_results, _ = self._concept_retrieval(self.symbolic_kb, concept)
264+
neural_results, _ = self._concept_retrieval(self.neural_owl_reasoner, concept)
265+
266+
# Calculate similarities
267+
jaccard_sim = self._jaccard_similarity(symbolic_results, neural_results)
268+
f1_sim = self._f1_set_similarity(symbolic_results, neural_results)
269+
270+
total_jaccard += jaccard_sim
271+
total_f1 += f1_sim
272+
count += 1
273+
274+
# Calculate averages
275+
avg_jaccard = total_jaccard / count if count > 0 else 0.0
276+
avg_f1 = total_f1 / count if count > 0 else 0.0
277+
278+
# Assert perfect average scores
279+
assert avg_jaccard == 1.0, f"Average Jaccard similarity is {avg_jaccard}, expected 1.0"
280+
assert avg_f1 == 1.0, f"Average F1 score is {avg_f1}, expected 1.0"
281+
282+
print(f"\nOverall Performance Summary:")
283+
print(f"Tested {count} concepts")
284+
print(f"Average Jaccard Similarity: {avg_jaccard:.4f}")
285+
print(f"Average F1 Score: {avg_f1:.4f}")

0 commit comments

Comments
 (0)