@@ -39,8 +39,24 @@ def __post_init__(self):
3939 for attr in ("element_probs" , "element_class_ids" , "texts" ):
4040 if getattr (self , attr ).size == 0 and element_size :
4141 setattr (self , attr , np .array ([None ] * element_size ))
42+
4243 self .element_probs = self .element_probs .astype (float )
4344
def __eq__(self, other: object) -> bool:
    """Return whether `other` describes the same layout elements as `self`.

    The comparison covers, in order: element coordinates, texts,
    probabilities, resolved element classes, and source.

    * Probabilities are compared NaN-aware: both operands must have NaN in
      the same positions and equal values everywhere else. A direct array
      comparison would report a mismatch whenever a NaN is present because
      NaN never equals itself.
    * Class ids are compared through `element_class_id_map`, i.e. by the
      value each id maps to rather than by the raw id, so two instances that
      number their classes differently can still compare equal.

    Returns `NotImplemented` when `other` is not a `LayoutElements`, so that
    Python falls back to its default handling (e.g. `layout == None` is
    `False` instead of raising `AttributeError`).
    """
    if not isinstance(other, LayoutElements):
        return NotImplemented

    def resolved_classes(layout):
        # Same fallback as `as_list`: a missing class id, or an empty id map,
        # resolves to None instead of failing the lookup.
        id_map = layout.element_class_id_map
        return [
            id_map[idx] if idx is not None and id_map else None
            for idx in layout.element_class_ids
        ]

    # True where a probability is actually present (not NaN).
    mask = ~np.isnan(self.element_probs)
    other_mask = ~np.isnan(other.element_probs)

    return (
        np.array_equal(self.element_coords, other.element_coords)
        and np.array_equal(self.texts, other.texts)
        # The masks must agree before indexing both arrays with `mask`;
        # short-circuiting guarantees that ordering.
        and np.array_equal(mask, other_mask)
        and np.array_equal(self.element_probs[mask], other.element_probs[mask])
        and resolved_classes(self) == resolved_classes(other)
        and self.source == other.source
    )
59+
4460 def slice (self , indices ) -> LayoutElements :
4561 """slice and return only selected indices"""
4662 return LayoutElements (
@@ -87,7 +103,7 @@ def as_list(self):
87103 if class_id is not None and self .element_class_id_map
88104 else None
89105 ),
90- prob = prob ,
106+ prob = None if np . isnan ( prob ) else prob ,
91107 source = self .source ,
92108 )
93109 for (x1 , y1 , x2 , y2 ), text , prob , class_id in zip (
@@ -114,9 +130,10 @@ def from_list(cls, elements: list):
114130 coords [i ] = [element .bbox .x1 , element .bbox .y1 , element .bbox .x2 , element .bbox .y2 ]
115131 texts .append (element .text )
116132 class_probs .append (element .prob )
117- class_types [i ] = element .type
133+ class_types [i ] = element .type or "None"
118134
119135 unique_ids , class_ids = np .unique (class_types , return_inverse = True )
136+ unique_ids [unique_ids == "None" ] = None
120137
121138 return cls (
122139 element_coords = coords ,
0 commit comments