Skip to content

Commit 933dfd4

Browse files
committed
fix: more type dealing with none vs nan
1 parent df1c82f commit 933dfd4

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

unstructured_inference/inference/layoutelement.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,24 @@ def __post_init__(self):
3939
for attr in ("element_probs", "element_class_ids", "texts"):
4040
if getattr(self, attr).size == 0 and element_size:
4141
setattr(self, attr, np.array([None] * element_size))
42+
4243
self.element_probs = self.element_probs.astype(float)
4344

45+
def __eq__(self, other: LayoutElements) -> bool:
46+
mask = ~np.isnan(self.element_probs)
47+
other_mask = ~np.isnan(other.element_probs)
48+
return (
49+
np.array_equal(self.element_coords, other.element_coords)
50+
and np.array_equal(self.texts, other.texts)
51+
and np.array_equal(mask, other_mask)
52+
and np.array_equal(self.element_probs[mask], other.element_probs[mask])
53+
and (
54+
[self.element_class_id_map[idx] for idx in self.element_class_ids]
55+
== [other.element_class_id_map[idx] for idx in other.element_class_ids]
56+
)
57+
and self.source == other.source
58+
)
59+
4460
def slice(self, indices) -> LayoutElements:
4561
"""slice and return only selected indices"""
4662
return LayoutElements(
@@ -87,7 +103,7 @@ def as_list(self):
87103
if class_id is not None and self.element_class_id_map
88104
else None
89105
),
90-
prob=prob,
106+
prob=None if np.isnan(prob) else prob,
91107
source=self.source,
92108
)
93109
for (x1, y1, x2, y2), text, prob, class_id in zip(
@@ -114,9 +130,10 @@ def from_list(cls, elements: list):
114130
coords[i] = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2]
115131
texts.append(element.text)
116132
class_probs.append(element.prob)
117-
class_types[i] = element.type
133+
class_types[i] = element.type or "None"
118134

119135
unique_ids, class_ids = np.unique(class_types, return_inverse=True)
136+
unique_ids[unique_ids == "None"] = None
120137

121138
return cls(
122139
element_coords=coords,

0 commit comments

Comments
 (0)