Skip to content

Commit 646582e

Browse files
committed
fix 7f759f0 multi-label classification, again
1 parent d5b7093 commit 646582e

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

src/data.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,14 +147,10 @@ def init(_data_dir,
147147

148148
def catalog_overlaps(data):
149149
data.sort(key=lambda x: x['ticks'][0])
150-
k = [True]*len(data)
151150
for i in range(len(data)):
152151
for j in range(i-1):
153-
if not k[j]: continue
154-
if data[j]['ticks'][1] < data[i]['ticks'][0]:
155-
k[j] = False
156-
continue
157-
data[i]['overlaps'].append(j)
152+
if data[j]['file'] == data[i]['file'] and data[j]['ticks'][1] > data[i]['ticks'][0]:
153+
data[i]['overlaps'].append(j)
158154

159155
def prepare_data_index(shiftby,
160156
labels_touse, kinds_touse,
@@ -345,7 +341,7 @@ def prepare_data_index(shiftby,
345341
for set_index in ['validation', 'testing', 'training']:
346342
print("num "+set_index+" labels")
347343
if set_index != 'testing':
348-
catalog_overlaps(data_index[set_index])
344+
if loss=="overlapped": catalog_overlaps(data_index[set_index])
349345
labels = [sound['label'] for sound in data_index[set_index]]
350346
for uniqlabel in sorted(set(labels)):
351347
print('%8d %s' % (sum(label==uniqlabel for label in labels), uniqlabel))
@@ -385,7 +381,7 @@ def prepare_data_index(shiftby,
385381
testing_max_sounds,
386382
replace=False).tolist()
387383
if set_index == 'testing':
388-
catalog_overlaps(data_index['testing'])
384+
if loss=="overlapped": catalog_overlaps(data_index['testing'])
389385
labels = [sound['label'] for sound in data_index['testing']]
390386
for uniqlabel in sorted(set(labels)):
391387
print('%7d %s' % (sum(label==uniqlabel for label in labels), uniqlabel))

0 commit comments

Comments
 (0)