@@ -235,20 +235,22 @@ def split_rows(
235235 prev_row = next (it )
236236 first_key = prev_key = get_key (prev_row )
237237 first_group = cur_group = [prev_row ]
238- yield cur_group
239238 for row in it :
240239 key = get_key (row )
241240 if next_key (prev_key ) == key :
242241 cur_group .append (row )
243242 else :
244- cur_group = [row ]
245243 yield cur_group
244+ cur_group = [row ]
246245 prev_row , prev_key = row , key
247246 if next_key (prev_key ) == first_key :
248247 if first_group is cur_group :
249248 first_group .append (rows [0 ])
249+ yield cur_group
250250 else :
251251 first_group [:0 ] = cur_group
252+ else :
253+ yield cur_group
252254
253255 def __new__ (
254256 cls ,
@@ -278,29 +280,50 @@ def __new__(
278280 groups_hc [f"h={ row .H } _c={ row .C } " ].append (row )
279281
280282 ret : list [DistanceDataset ] = []
281- for groups , ( key , key_next ) in (
282- (groups_hv , ( cls .key_c , cls .key_c_next ) ),
283- (groups_vc , ( cls .key_h , cls .key_h_next ) ),
284- (groups_hc , ( cls .key_v , cls .key_v_next ) ),
283+ for groups , key , key_next in (
284+ (groups_hv , cls .key_c , cls .key_c_next ),
285+ (groups_vc , cls .key_h , cls .key_h_next ),
286+ (groups_hc , cls .key_v , cls .key_v_next ),
285287 ):
286288 for name , group in groups .items ():
287289 group .sort (key = key )
290+ subsets : list [list [MunsellRow ]] = []
288291 for group in list (cls .split_rows (group , key , key_next )):
289- if len (group ) > min_subset_size : # > and not >= !
292+ subsets .append (group )
293+ if subsets :
294+ if (
295+ sum (len (subset ) for subset in subsets )
296+ > min_subset_size
297+ ): # filter by number of pairs. > and not >= !
290298 ret .append (
291- cls .group_as_dataset (f"{ version } -{ name } " , group )
299+ cls .subsets_as_dataset (
300+ f"{ version } -{ name } " , subsets
301+ )
292302 )
293303 return ret
294304
295305 @staticmethod
296- def group_as_dataset (key : str , rows : list [MunsellRow ]) -> DistanceDataset :
306+ def subsets_as_dataset (
307+ key : str , subsets : list [list [MunsellRow ]]
308+ ) -> DistanceDataset :
297309 from vsl_ial .cs import whitepoints_cie1931
298310 from vsl_ial .cs .ciexyy import CIExyY
299311
300- n = len (rows ) - 1
312+ pairs : list [tuple [int , int ]] = []
313+ shift = 0
314+ for rows in subsets :
315+ n = len (rows )
316+ pairs .extend (
317+ zip (
318+ range (shift + 0 , shift + n - 1 ),
319+ range (shift + 1 , shift + n ),
320+ )
321+ )
322+ shift += len (rows )
301323
302324 xyY = np .asarray (
303- [(row .x , row .y , row .Y * 0.01 ) for row in rows ], dtype = np .float64
325+ [(row .x , row .y , row .Y * 0.01 ) for rows in subsets for row in rows ],
326+ dtype = np .float64 ,
304327 )
305328 xyz = CIExyY (None ).to_XYZ (None , xyY )
306329
@@ -312,7 +335,8 @@ def group_as_dataset(key: str, rows: list[MunsellRow]) -> DistanceDataset:
312335 Nc = 1.0 ,
313336 F = 1.0 ,
314337 illuminant = whitepoints_cie1931 .C ,
315- dv = np .full (shape = (n ,), fill_value = np .float64 (1.0 )),
316- pairs = list (zip (range (n ), range (1 , n + 1 ))),
338+ # 1.0 is a perceptual step, we don't know its exact value
339+ dv = np .full (shape = (len (pairs ),), fill_value = np .float64 (1.0 )),
340+ pairs = pairs ,
317341 xyz = xyz ,
318342 )
0 commit comments