@@ -92,49 +92,93 @@ def _detect_by_ocr(self, gray, h, w, hsv):
9292 if width < 6 :
9393 continue
9494
95- # Sub-column splitting for very wide groups only
96- # Train IDs are typically 15-25 pixels wide
97- if width > 35 :
98- sub_cols = [(sub_x , min (sub_x + 20 , x2 ))
99- for sub_x in range (x1 , x2 , 18 )]
95+ # Sub-column splitting for groups wider than a single train ID
96+ # A typical train ID is 10-18 pixels wide
97+ # Groups 20+ pixels wide may contain multiple trains
98+ if width > 20 :
99+ # Try to find the best split point (largest internal gap)
100+ sub_cols = self ._split_wide_group (x1 , x2 , col_sums )
100101 else :
101102 sub_cols = [(x1 , x2 )]
102103
103104 for sub_x1 , sub_x2 in sub_cols :
104- train_id = self ._ocr_column (band , dark_mask , sub_x1 , sub_x2 , band_h , track )
105- if train_id :
106- center_x = (sub_x1 + sub_x2 ) // 2
107- trains .append ({
108- 'id' : train_id ,
109- 'x' : center_x ,
110- 'y' : y_min + band_h // 2 ,
111- 'track' : track ,
112- 'confidence' : 'high'
113- })
105+ train_ids = self ._ocr_column (band , dark_mask , sub_x1 , sub_x2 , band_h , track )
106+ if train_ids :
107+ # Distribute multiple train IDs across the column width
108+ sub_width = sub_x2 - sub_x1
109+ n_trains = len (train_ids )
110+ for i , train_id in enumerate (train_ids ):
111+ if n_trains == 1 :
112+ center_x = (sub_x1 + sub_x2 ) // 2
113+ else :
114+ # Space trains evenly across the column
115+ center_x = sub_x1 + int (sub_width * (i + 0.5 ) / n_trains )
116+ trains .append ({
117+ 'id' : train_id ,
118+ 'x' : center_x ,
119+ 'y' : y_min + band_h // 2 ,
120+ 'track' : track ,
121+ 'confidence' : 'high'
122+ })
114123
115124 # Also detect colored text labels (yellow, green)
116125 colored_trains = self ._detect_colored_labels (band , band_hsv , band_h , y_min , track )
117126 trains .extend (colored_trains )
118127
119128 return self ._deduplicate (trains )
120129
121- def _group_columns (self , columns ):
122- """Group adjacent columns."""
130+ def _split_wide_group (self , x1 , x2 , col_sums ):
131+ """Split a wide column group at the best split point.
132+
133+ For groups that may contain multiple trains, find the column with
134+ the minimum sum (gap between trains) and split there.
135+ """
136+ width = x2 - x1
137+ if width <= 20 :
138+ return [(x1 , x2 )]
139+
140+ # Find the minimum column sum in the middle portion of the group
141+ # (don't split at the edges)
142+ margin = max (6 , width // 5 )
143+ mid_start = x1 + margin
144+ mid_end = x2 - margin
145+
146+ if mid_start >= mid_end :
147+ return [(x1 , x2 )]
148+
149+ # Find the column with minimum sum (likely a gap between trains)
150+ min_sum = float ('inf' )
151+ split_x = None
152+ for x in range (mid_start , mid_end ):
153+ if col_sums [x ] < min_sum :
154+ min_sum = col_sums [x ]
155+ split_x = x
156+
157+ # Only split if we found a true gap (zero column sum)
158+ # Gaps within train IDs typically have sum > 0
159+ # Gaps between adjacent trains typically have sum = 0
160+ if split_x and min_sum == 0 :
161+ return [(x1 , split_x ), (split_x + 1 , x2 )]
162+ else :
163+ return [(x1 , x2 )]
164+
165+ def _group_columns (self , columns , gap_threshold = 10 ):
166+ """Group adjacent columns. Gap threshold determines when to split groups."""
123167 if len (columns ) == 0 :
124168 return []
125169 groups = []
126170 start = columns [0 ]
127171 prev = start
128172 for x in columns [1 :]:
129- if x - prev > 10 :
173+ if x - prev > gap_threshold :
130174 groups .append ((start , prev ))
131175 start = x
132176 prev = x
133177 groups .append ((start , prev ))
134178 return groups
135179
136180 def _ocr_column (self , band , dark_mask , x1 , x2 , band_h , track ):
137- """OCR a single column - ONE preprocessing pass ."""
181+ """OCR a single column - returns list of train IDs found ."""
138182 pad = 3
139183 roi_x1 = max (0 , x1 - pad )
140184 roi_x2 = min (band .shape [1 ], x2 + pad )
@@ -144,25 +188,36 @@ def _ocr_column(self, band, dark_mask, x1, x2, band_h, track):
144188 row_sums = col_mask .sum (axis = 1 )
145189 text_rows = np .where (row_sums > 0 )[0 ]
146190 if len (text_rows ) < 5 :
147- return None
191+ return []
148192
149193 y1 = max (0 , text_rows [0 ] - 2 )
150194 y2 = min (band_h , text_rows [- 1 ] + 2 )
151195
152196 # Skip station label regions
153197 if track == 'lower' :
198+ # Lower track: station labels are at top of band
154199 station_cutoff = int (band_h * 0.15 )
155200 if y1 < station_cutoff :
156201 y1 = station_cutoff
157- elif track == 'upper' :
158- # Allow more of the band - station labels are filtered by name matching
159- station_cutoff = int (band_h * 0.95 )
160- if y2 > station_cutoff :
161- y2 = station_cutoff
202+ else :
203+ # Upper track: station labels are at bottom of band (e.g., "Embarcadero")
204+ # Look for a gap of 8+ rows with no text - station labels are after such gaps
205+ gap_start = None
206+ for y in range (int (band_h * 0.70 ), y2 ):
207+ if row_sums [y ] == 0 :
208+ if gap_start is None :
209+ gap_start = y
210+ elif gap_start is not None :
211+ gap_len = y - gap_start
212+ if gap_len >= 8 :
213+ # Found significant gap - clip before it
214+ y2 = gap_start
215+ break
216+ gap_start = None
162217
163218 roi = band [y1 :y2 , roi_x1 :roi_x2 ]
164219 if roi .size == 0 or roi .shape [0 ] < 20 :
165- return None
220+ return []
166221
167222 # Single OCR pass with Otsu
168223 scale = 4
@@ -171,9 +226,9 @@ def _ocr_column(self, band, dark_mask, x1, x2, band_h, track):
171226
172227 try :
173228 text = pytesseract .image_to_string (roi_bin , config = '--psm 6' )
174- return self ._extract_train_id (text )
229+ return self ._extract_train_ids (text )
175230 except Exception :
176- return None
231+ return []
177232
178233 def _detect_colored_labels (self , band_gray , band_hsv , band_h , y_offset , track ):
179234 """Detect colored text labels (yellow, green train IDs)."""
@@ -211,15 +266,12 @@ def _detect_colored_labels(self, band_gray, band_hsv, band_h, y_offset, track):
211266 y1 = max (0 , text_rows [0 ] - 2 )
212267 y2 = min (band_h , text_rows [- 1 ] + 2 )
213268
214- # Skip station label areas
269+ # Skip station label areas (lower track only)
270+ # Upper track relies on station label filtering in _extract_train_id()
215271 if track == 'lower' :
216272 station_cutoff = int (band_h * 0.15 )
217273 if y1 < station_cutoff :
218274 y1 = station_cutoff
219- elif track == 'upper' :
220- station_cutoff = int (band_h * 0.80 )
221- if y2 > station_cutoff :
222- y2 = station_cutoff
223275
224276 if y2 - y1 < 20 :
225277 continue
@@ -390,8 +442,8 @@ def _merge(self, ocr_trains, symbols):
390442
391443 return trains
392444
393- def _extract_train_id (self , text ):
394- """Extract train ID from OCR text."""
445+ def _extract_train_ids (self , text ):
446+ """Extract all train IDs from OCR text. Returns list of train IDs ."""
395447 lines = [c .strip () for c in text .split ('\n ' ) if c .strip ()]
396448 combined = '' .join (lines ).upper ()
397449
@@ -408,6 +460,23 @@ def _extract_train_id(self, text):
408460 combined = combined .replace ('Z' , '7' )
409461 combined = combined .replace ('(' , '' ).replace (')' , '' ).replace (' ' , '' )
410462
463+ # OCR sometimes reads digits as 'RE' - could be 5 or 7
464+ # Try both interpretations and return all unique valid train IDs
465+ # Prefer 5 over 7 since RE is more commonly a misread 5
466+ if 'RE' in combined :
467+ combined_5 = combined .replace ('RE' , '5' )
468+ combined_7 = combined .replace ('RE' , '7' )
469+ matches_5 = TRAIN_ID_PATTERN .findall (combined_5 )
470+ matches_7 = TRAIN_ID_PATTERN .findall (combined_7 )
471+ # Return 5-versions first, then any unique 7-versions
472+ result = list (matches_5 )
473+ for m in matches_7 :
474+ if m not in result :
475+ result .append (m )
476+ if result :
477+ return result
478+ combined = combined_5 # Fall through if neither worked
479+
411480 # Fix duplicate leading letters (OCR sometimes doubles them in vertical text)
412481 if len (combined ) >= 2 and combined [0 ].isalpha () and combined [0 ] == combined [1 ]:
413482 combined = combined [1 :]
@@ -417,7 +486,7 @@ def _extract_train_id(self, text):
417486
418487 matches = TRAIN_ID_PATTERN .findall (combined )
419488 if matches :
420- return matches [ 0 ]
489+ return matches # Return ALL matches
421490
422491 # Position-based 7/T fix for vertical text where 7 is often read as T
423492 if len (combined ) >= 5 :
@@ -444,12 +513,17 @@ def _extract_train_id(self, text):
444513 fixed [i ] = 'O'
445514 matches = TRAIN_ID_PATTERN .findall ('' .join (fixed ))
446515 if matches :
447- return matches [ 0 ]
516+ return matches # Return ALL matches
448517
449- return None
518+ return []
519+
520+ def _extract_train_id (self , text ):
521+ """Extract first train ID from OCR text (for backward compatibility)."""
522+ matches = self ._extract_train_ids (text )
523+ return matches [0 ] if matches else None
450524
451525 def _deduplicate (self , trains ):
452- """Remove duplicates."""
526+ """Remove duplicates (same train detected twice), but keep bunched trains ."""
453527 if not trains :
454528 return trains
455529 trains = sorted (trains , key = lambda t : (t ['track' ], t ['x' ]))
@@ -459,11 +533,26 @@ def _deduplicate(self, trains):
459533 for existing in unique :
460534 if (existing ['track' ] == train ['track' ] and
461535 abs (existing ['x' ] - train ['x' ]) < 30 ):
462- if len (train ['id' ]) > len (existing ['id' ]):
463- unique .remove (existing )
464- unique .append (train )
465- is_dup = True
466- break
536+ # Only consider duplicates if IDs are similar
537+ # (one is prefix of other, or differ by at most 2 chars)
538+ if self ._ids_are_similar (existing ['id' ], train ['id' ]):
539+ if len (train ['id' ]) > len (existing ['id' ]):
540+ unique .remove (existing )
541+ unique .append (train )
542+ is_dup = True
543+ break
467544 if not is_dup :
468545 unique .append (train )
469546 return unique
547+
548+ def _ids_are_similar (self , id1 , id2 ):
549+ """Check if two train IDs are likely the same train detected twice."""
550+ # One is a prefix of the other (e.g., "M2034L" and "M2034LL")
551+ if id1 .startswith (id2 ) or id2 .startswith (id1 ):
552+ return True
553+ # Same length and differ by at most 2 characters
554+ if len (id1 ) == len (id2 ):
555+ diff = sum (c1 != c2 for c1 , c2 in zip (id1 , id2 ))
556+ if diff <= 2 :
557+ return True
558+ return False
0 commit comments