Skip to content

Commit 4ecc883

Browse files
committed
Improve train detection
1 parent 9185c1b commit 4ecc883

File tree

4 files changed

+181
-70
lines changed

4 files changed

+181
-70
lines changed

scripts/train_detector.py

Lines changed: 133 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -92,49 +92,93 @@ def _detect_by_ocr(self, gray, h, w, hsv):
9292
if width < 6:
9393
continue
9494

95-
# Sub-column splitting for very wide groups only
96-
# Train IDs are typically 15-25 pixels wide
97-
if width > 35:
98-
sub_cols = [(sub_x, min(sub_x + 20, x2))
99-
for sub_x in range(x1, x2, 18)]
95+
# Sub-column splitting for groups wider than a single train ID
96+
# A typical train ID is 10-18 pixels wide
97+
# Groups 20+ pixels wide may contain multiple trains
98+
if width > 20:
99+
# Try to find the best split point (largest internal gap)
100+
sub_cols = self._split_wide_group(x1, x2, col_sums)
100101
else:
101102
sub_cols = [(x1, x2)]
102103

103104
for sub_x1, sub_x2 in sub_cols:
104-
train_id = self._ocr_column(band, dark_mask, sub_x1, sub_x2, band_h, track)
105-
if train_id:
106-
center_x = (sub_x1 + sub_x2) // 2
107-
trains.append({
108-
'id': train_id,
109-
'x': center_x,
110-
'y': y_min + band_h // 2,
111-
'track': track,
112-
'confidence': 'high'
113-
})
105+
train_ids = self._ocr_column(band, dark_mask, sub_x1, sub_x2, band_h, track)
106+
if train_ids:
107+
# Distribute multiple train IDs across the column width
108+
sub_width = sub_x2 - sub_x1
109+
n_trains = len(train_ids)
110+
for i, train_id in enumerate(train_ids):
111+
if n_trains == 1:
112+
center_x = (sub_x1 + sub_x2) // 2
113+
else:
114+
# Space trains evenly across the column
115+
center_x = sub_x1 + int(sub_width * (i + 0.5) / n_trains)
116+
trains.append({
117+
'id': train_id,
118+
'x': center_x,
119+
'y': y_min + band_h // 2,
120+
'track': track,
121+
'confidence': 'high'
122+
})
114123

115124
# Also detect colored text labels (yellow, green)
116125
colored_trains = self._detect_colored_labels(band, band_hsv, band_h, y_min, track)
117126
trains.extend(colored_trains)
118127

119128
return self._deduplicate(trains)
120129

121-
def _group_columns(self, columns):
122-
"""Group adjacent columns."""
130+
def _split_wide_group(self, x1, x2, col_sums):
131+
"""Split a wide column group at the best split point.
132+
133+
For groups that may contain multiple trains, find the column with
134+
the minimum sum (gap between trains) and split there.
135+
"""
136+
width = x2 - x1
137+
if width <= 20:
138+
return [(x1, x2)]
139+
140+
# Find the minimum column sum in the middle portion of the group
141+
# (don't split at the edges)
142+
margin = max(6, width // 5)
143+
mid_start = x1 + margin
144+
mid_end = x2 - margin
145+
146+
if mid_start >= mid_end:
147+
return [(x1, x2)]
148+
149+
# Find the column with minimum sum (likely a gap between trains)
150+
min_sum = float('inf')
151+
split_x = None
152+
for x in range(mid_start, mid_end):
153+
if col_sums[x] < min_sum:
154+
min_sum = col_sums[x]
155+
split_x = x
156+
157+
# Only split if we found a true gap (zero column sum)
158+
# Gaps within train IDs typically have sum > 0
159+
# Gaps between adjacent trains typically have sum = 0
160+
if split_x and min_sum == 0:
161+
return [(x1, split_x), (split_x + 1, x2)]
162+
else:
163+
return [(x1, x2)]
164+
165+
def _group_columns(self, columns, gap_threshold=10):
166+
"""Group adjacent columns. Gap threshold determines when to split groups."""
123167
if len(columns) == 0:
124168
return []
125169
groups = []
126170
start = columns[0]
127171
prev = start
128172
for x in columns[1:]:
129-
if x - prev > 10:
173+
if x - prev > gap_threshold:
130174
groups.append((start, prev))
131175
start = x
132176
prev = x
133177
groups.append((start, prev))
134178
return groups
135179

136180
def _ocr_column(self, band, dark_mask, x1, x2, band_h, track):
137-
"""OCR a single column - ONE preprocessing pass."""
181+
"""OCR a single column - returns list of train IDs found."""
138182
pad = 3
139183
roi_x1 = max(0, x1 - pad)
140184
roi_x2 = min(band.shape[1], x2 + pad)
@@ -144,25 +188,36 @@ def _ocr_column(self, band, dark_mask, x1, x2, band_h, track):
144188
row_sums = col_mask.sum(axis=1)
145189
text_rows = np.where(row_sums > 0)[0]
146190
if len(text_rows) < 5:
147-
return None
191+
return []
148192

149193
y1 = max(0, text_rows[0] - 2)
150194
y2 = min(band_h, text_rows[-1] + 2)
151195

152196
# Skip station label regions
153197
if track == 'lower':
198+
# Lower track: station labels are at top of band
154199
station_cutoff = int(band_h * 0.15)
155200
if y1 < station_cutoff:
156201
y1 = station_cutoff
157-
elif track == 'upper':
158-
# Allow more of the band - station labels are filtered by name matching
159-
station_cutoff = int(band_h * 0.95)
160-
if y2 > station_cutoff:
161-
y2 = station_cutoff
202+
else:
203+
# Upper track: station labels are at bottom of band (e.g., "Embarcadero")
204+
# Look for a gap of 8+ rows with no text - station labels are after such gaps
205+
gap_start = None
206+
for y in range(int(band_h * 0.70), y2):
207+
if row_sums[y] == 0:
208+
if gap_start is None:
209+
gap_start = y
210+
elif gap_start is not None:
211+
gap_len = y - gap_start
212+
if gap_len >= 8:
213+
# Found significant gap - clip before it
214+
y2 = gap_start
215+
break
216+
gap_start = None
162217

163218
roi = band[y1:y2, roi_x1:roi_x2]
164219
if roi.size == 0 or roi.shape[0] < 20:
165-
return None
220+
return []
166221

167222
# Single OCR pass with Otsu
168223
scale = 4
@@ -171,9 +226,9 @@ def _ocr_column(self, band, dark_mask, x1, x2, band_h, track):
171226

172227
try:
173228
text = pytesseract.image_to_string(roi_bin, config='--psm 6')
174-
return self._extract_train_id(text)
229+
return self._extract_train_ids(text)
175230
except Exception:
176-
return None
231+
return []
177232

178233
def _detect_colored_labels(self, band_gray, band_hsv, band_h, y_offset, track):
179234
"""Detect colored text labels (yellow, green train IDs)."""
@@ -211,15 +266,12 @@ def _detect_colored_labels(self, band_gray, band_hsv, band_h, y_offset, track):
211266
y1 = max(0, text_rows[0] - 2)
212267
y2 = min(band_h, text_rows[-1] + 2)
213268

214-
# Skip station label areas
269+
# Skip station label areas (lower track only)
270+
# Upper track relies on station label filtering in _extract_train_id()
215271
if track == 'lower':
216272
station_cutoff = int(band_h * 0.15)
217273
if y1 < station_cutoff:
218274
y1 = station_cutoff
219-
elif track == 'upper':
220-
station_cutoff = int(band_h * 0.80)
221-
if y2 > station_cutoff:
222-
y2 = station_cutoff
223275

224276
if y2 - y1 < 20:
225277
continue
@@ -390,8 +442,8 @@ def _merge(self, ocr_trains, symbols):
390442

391443
return trains
392444

393-
def _extract_train_id(self, text):
394-
"""Extract train ID from OCR text."""
445+
def _extract_train_ids(self, text):
446+
"""Extract all train IDs from OCR text. Returns list of train IDs."""
395447
lines = [c.strip() for c in text.split('\n') if c.strip()]
396448
combined = ''.join(lines).upper()
397449

@@ -408,6 +460,23 @@ def _extract_train_id(self, text):
408460
combined = combined.replace('Z', '7')
409461
combined = combined.replace('(', '').replace(')', '').replace(' ', '')
410462

463+
# OCR sometimes reads digits as 'RE' - could be 5 or 7
464+
# Try both interpretations and return all unique valid train IDs
465+
# Prefer 5 over 7 since RE is more commonly a misread 5
466+
if 'RE' in combined:
467+
combined_5 = combined.replace('RE', '5')
468+
combined_7 = combined.replace('RE', '7')
469+
matches_5 = TRAIN_ID_PATTERN.findall(combined_5)
470+
matches_7 = TRAIN_ID_PATTERN.findall(combined_7)
471+
# Return 5-versions first, then any unique 7-versions
472+
result = list(matches_5)
473+
for m in matches_7:
474+
if m not in result:
475+
result.append(m)
476+
if result:
477+
return result
478+
combined = combined_5 # Fall through if neither worked
479+
411480
# Fix duplicate leading letters (OCR sometimes doubles them in vertical text)
412481
if len(combined) >= 2 and combined[0].isalpha() and combined[0] == combined[1]:
413482
combined = combined[1:]
@@ -417,7 +486,7 @@ def _extract_train_id(self, text):
417486

418487
matches = TRAIN_ID_PATTERN.findall(combined)
419488
if matches:
420-
return matches[0]
489+
return matches # Return ALL matches
421490

422491
# Position-based 7/T fix for vertical text where 7 is often read as T
423492
if len(combined) >= 5:
@@ -444,12 +513,17 @@ def _extract_train_id(self, text):
444513
fixed[i] = 'O'
445514
matches = TRAIN_ID_PATTERN.findall(''.join(fixed))
446515
if matches:
447-
return matches[0]
516+
return matches # Return ALL matches
448517

449-
return None
518+
return []
519+
520+
def _extract_train_id(self, text):
521+
"""Extract first train ID from OCR text (for backward compatibility)."""
522+
matches = self._extract_train_ids(text)
523+
return matches[0] if matches else None
450524

451525
def _deduplicate(self, trains):
452-
"""Remove duplicates."""
526+
"""Remove duplicates (same train detected twice), but keep bunched trains."""
453527
if not trains:
454528
return trains
455529
trains = sorted(trains, key=lambda t: (t['track'], t['x']))
@@ -459,11 +533,26 @@ def _deduplicate(self, trains):
459533
for existing in unique:
460534
if (existing['track'] == train['track'] and
461535
abs(existing['x'] - train['x']) < 30):
462-
if len(train['id']) > len(existing['id']):
463-
unique.remove(existing)
464-
unique.append(train)
465-
is_dup = True
466-
break
536+
# Only consider duplicates if IDs are similar
537+
# (one is prefix of other, or differ by at most 2 chars)
538+
if self._ids_are_similar(existing['id'], train['id']):
539+
if len(train['id']) > len(existing['id']):
540+
unique.remove(existing)
541+
unique.append(train)
542+
is_dup = True
543+
break
467544
if not is_dup:
468545
unique.append(train)
469546
return unique
547+
548+
def _ids_are_similar(self, id1, id2):
549+
"""Check if two train IDs are likely the same train detected twice."""
550+
# One is a prefix of the other (e.g., "M2034L" and "M2034LL")
551+
if id1.startswith(id2) or id2.startswith(id1):
552+
return True
553+
# Same length and differ by at most 2 characters
554+
if len(id1) == len(id2):
555+
diff = sum(c1 != c2 for c1, c2 in zip(id1, id2))
556+
if diff <= 2:
557+
return True
558+
return False

0 commit comments

Comments
 (0)