Skip to content

Commit 08f27b8

Browse files
authored
feat(Taxonomy): children method parity with parent methods (#467)
1 parent 3252d61 commit 08f27b8

2 files changed

Lines changed: 129 additions & 8 deletions

File tree

src/openfoodfacts/taxonomy.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,18 @@ def is_parent_of(self, candidate: "TaxonomyNode") -> bool:
156156
"""
157157
return candidate.is_child_of(self)
158158

159+
def is_child_of_any(self, candidates: Iterable["TaxonomyNode"]) -> bool:
160+
"""Return True if `self` is a child of any of `candidates`, False
161+
otherwise.
162+
163+
:param candidates: an iterable of TaxonomyNodes of the same Taxonomy
164+
"""
165+
for candidate in candidates:
166+
if candidate.is_parent_of(self):
167+
return True
168+
169+
return False
170+
159171
def is_parent_of_any(self, candidates: Iterable["TaxonomyNode"]) -> bool:
160172
"""Return True if `self` is a parent of any of `candidates`, False
161173
otherwise.
@@ -168,6 +180,26 @@ def is_parent_of_any(self, candidates: Iterable["TaxonomyNode"]) -> bool:
168180

169181
return False
170182

183+
def get_children_hierarchy(self) -> List["TaxonomyNode"]:
184+
"""Return the list of all child nodes (direct and indirect)."""
185+
all_children = []
186+
seen: Set[str] = set()
187+
188+
if not self.children:
189+
return []
190+
191+
for self_child in self.children:
192+
if self_child.id not in seen:
193+
all_children.append(self_child)
194+
seen.add(self_child.id)
195+
196+
for child_child in self_child.get_children_hierarchy():
197+
if child_child.id not in seen:
198+
all_children.append(child_child)
199+
seen.add(child_child.id)
200+
201+
return all_children
202+
171203
def get_parents_hierarchy(self) -> List["TaxonomyNode"]:
172204
"""Return the list of all parent nodes (direct and indirect)."""
173205
all_parents = []
@@ -284,6 +316,36 @@ def find_deepest_nodes(self, nodes: List[TaxonomyNode]) -> List[TaxonomyNode]:
284316

285317
return [node for node in nodes if node.id not in excluded]
286318

319+
def is_child_of_any(
320+
self, item: str, candidates: Iterable[str], raises: bool = True
321+
) -> bool:
322+
"""Return True if `item` is child of any candidate, False otherwise.
323+
324+
If the item is not in the taxonomy and raises is False, return False.
325+
326+
:param item: The item to compare
327+
:param candidates: A list of candidates
328+
:param raises: if True, raises a ValueError if item is not in the
329+
taxonomy, defaults to True.
330+
"""
331+
node: TaxonomyNode = self[item]
332+
333+
if node is None:
334+
if raises:
335+
raise ValueError("unknown id in taxonomy: %s", node)
336+
else:
337+
return False
338+
339+
to_check_nodes: Set[TaxonomyNode] = set()
340+
341+
for candidate in candidates:
342+
candidate_node = self[candidate]
343+
344+
if candidate_node is not None:
345+
to_check_nodes.add(candidate_node)
346+
347+
return node.is_child_of_any(to_check_nodes)
348+
287349
def is_parent_of_any(
288350
self, item: str, candidates: Iterable[str], raises: bool = True
289351
) -> bool:

tests/unit/test_taxonomy.py

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -170,36 +170,95 @@ def test_create_brand_taxonomy_mapping(self):
170170

171171

172172
class TestTaxonomy:
173+
@pytest.mark.parametrize(
174+
"taxonomy,item,candidates,output",
175+
[
176+
(label_taxonomy, "en:fr-bio-01", {"en:organic"}, True),
177+
(label_taxonomy, "en:organic", {"en:fr-bio-01"}, False),
178+
(label_taxonomy, "en:fr-bio-01", [], False),
179+
(label_taxonomy, "en:no-gluten", {"en:organic"}, False),
180+
(
181+
label_taxonomy,
182+
"en:no-gluten",
183+
{"en:organic", "en:no-additives", "en:vegan"},
184+
False,
185+
),
186+
(
187+
label_taxonomy,
188+
"en:fr-bio-16",
189+
{"en:organic", "en:no-gluten", "en:no-additives", "en:vegan"},
190+
True,
191+
),
192+
],
193+
)
194+
def test_is_child_of_any(
195+
self, taxonomy: Taxonomy, item: str, candidates: list[str], output: bool
196+
):
197+
assert taxonomy.is_child_of_any(item, candidates) is output
198+
199+
def test_is_child_of_any_unknown_item(self):
200+
with pytest.raises(ValueError):
201+
label_taxonomy.is_child_of_any("unknown-id", set())
202+
173203
@pytest.mark.parametrize(
174204
"taxonomy,item,candidates,output",
175205
[
176206
(label_taxonomy, "en:organic", {"en:fr-bio-01"}, True),
177207
(label_taxonomy, "en:fr-bio-01", {"en:organic"}, False),
178208
(label_taxonomy, "en:fr-bio-01", [], False),
179-
(label_taxonomy, "en:organic", {"en:gluten-free"}, False),
209+
(label_taxonomy, "en:organic", {"en:no-gluten"}, False),
180210
(
181211
label_taxonomy,
182212
"en:organic",
183-
{"en:gluten-free", "en:no-additives", "en:vegan"},
213+
{"en:no-gluten", "en:no-additives", "en:vegan"},
184214
False,
185215
),
186216
(
187217
label_taxonomy,
188218
"en:organic",
189-
{"en:gluten-free", "en:no-additives", "en:fr-bio-16"},
219+
{"en:no-gluten", "en:no-additives", "en:fr-bio-16"},
190220
True,
191221
),
192222
],
193223
)
194-
def test_is_child_of_any(
195-
self, taxonomy: Taxonomy, item: str, candidates: list, output: bool
224+
def test_is_parent_of_any(
225+
self, taxonomy: Taxonomy, item: str, candidates: list[str], output: bool
196226
):
197227
assert taxonomy.is_parent_of_any(item, candidates) is output
198228

199-
def test_is_child_of_any_unknwon_item(self):
229+
def test_is_parent_of_any_unknown_item(self):
200230
with pytest.raises(ValueError):
201231
label_taxonomy.is_parent_of_any("unknown-id", set())
202232

233+
@pytest.mark.parametrize(
234+
"taxonomy,item,output",
235+
[
236+
(category_taxonomy, "en:brown-camargue-rices", set()),
237+
(
238+
category_taxonomy,
239+
"en:cooked-brown-rices",
240+
{"en:unsalted-cooked-brown-rices"},
241+
),
242+
(
243+
category_taxonomy,
244+
"en:brown-rices",
245+
{
246+
"en:brown-jasmine-rices",
247+
"en:brown-basmati-rices",
248+
"en:brown-camargue-rices",
249+
"en:cooked-brown-rices",
250+
"en:unsalted-cooked-brown-rices",
251+
},
252+
),
253+
],
254+
)
255+
def test_get_children_hierarchy(
256+
self, taxonomy: Taxonomy, item: str, output: set[str]
257+
):
258+
node = taxonomy[item]
259+
children_list = node.get_children_hierarchy()
260+
assert set((x.id for x in children_list)) == output
261+
203262
@pytest.mark.parametrize(
204263
"taxonomy,item,output",
205264
[
@@ -228,8 +287,8 @@ def test_get_parents_hierarchy(
228287
self, taxonomy: Taxonomy, item: str, output: set[str]
229288
):
230289
node = taxonomy[item]
231-
parents = node.get_parents_hierarchy()
232-
assert set((x.id for x in parents)) == output
290+
parents_list = node.get_parents_hierarchy()
291+
assert set((x.id for x in parents_list)) == output
233292

234293
@pytest.mark.parametrize(
235294
"taxonomy,items,output",

0 commit comments

Comments
 (0)