Skip to content

Commit 6820cc8

Browse files
authored
Merge pull request #43 from Kitware/tas/original-fs
Add original sample rate products and time buffer
2 parents 740cba2 + a7d2732 commit 6820cc8

2 files changed

Lines changed: 186 additions & 7 deletions

File tree

batbot/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from batbot import spectrogram # NOQA
3535

36-
VERSION = '0.1.4'
36+
VERSION = '0.1.5'
3737
version = VERSION
3838
__version__ = VERSION
3939

@@ -69,6 +69,8 @@ def pipeline(
6969
force_overwrite=False,
7070
quiet=False,
7171
plot_uncompressed_amplitude=False,
72+
include_original_sr=False,
73+
time_buffer_ms=1.0,
7274
debug=False,
7375
):
7476
"""
@@ -109,6 +111,8 @@ def pipeline(
109111
force_overwrite=force_overwrite,
110112
quiet=quiet,
111113
plot_uncompressed_amplitude=plot_uncompressed_amplitude,
114+
include_original_sr=include_original_sr,
115+
time_buffer_ms=time_buffer_ms,
112116
debug=debug,
113117
)
114118

@@ -308,6 +312,8 @@ def example():
308312
fast_mode=False,
309313
force_overwrite=True,
310314
plot_uncompressed_amplitude=True,
315+
include_original_sr=True,
316+
time_buffer_ms=5.0,
311317
)
312318
stop_time = time.time()
313319
print('Example pipeline completed in {} seconds.'.format(stop_time - start_time))

batbot/spectrogram/__init__.py

Lines changed: 179 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def load_stft(
253253
win_length=256,
254254
hop_length=16,
255255
fast_mode=False,
256+
use_original_sr=False,
256257
):
257258
assert exists(wav_filepath)
258259
log.debug(f'Computing spectrogram on {wav_filepath}')
@@ -265,7 +266,18 @@ def load_stft(
265266
raise OSError(f'Error loading file: {e}')
266267

267268
# Resample the waveform
268-
waveform = librosa.resample(waveform_, orig_sr=orig_sr, target_sr=sr)
269+
if not use_original_sr:
270+
waveform = librosa.resample(waveform_, orig_sr=orig_sr, target_sr=sr)
271+
else:
272+
waveform = waveform_
273+
# # define a next-power-of-2 factor to increase window and hop length
274+
# sr_factor = np.pow(2, np.ceil(np.log2(orig_sr / sr)))
275+
sr_factor = orig_sr / sr
276+
277+
sr *= sr_factor
278+
n_fft = int(np.round(n_fft * sr_factor))
279+
win_length = int(np.round(win_length * sr_factor))
280+
hop_length = int(np.round(hop_length * sr_factor))
269281

270282
# TODO: signal processing: remove DC offset, time window edges of waveform
271283

@@ -292,7 +304,7 @@ def load_stft(
292304
band_min = bands[index] - delta_f / 2.0
293305
band_max = bands[index] + delta_f / 2.0
294306
# accept bands with any part of their range within interval [FREQ_MIN, FREQ_MAX]
295-
if FREQ_MIN <= band_max and band_min <= FREQ_MAX:
307+
if FREQ_MIN <= band_max and (use_original_sr or band_min <= FREQ_MAX):
296308
goods.append(index)
297309
min_index = min(goods)
298310
max_index = max(goods)
@@ -592,14 +604,15 @@ def tighten_ranges(
592604
duration,
593605
skew_stddev=2.0,
594606
min_duration_ms=2.0,
607+
extra_buffer_pix=0.0,
595608
output_path='.',
596609
quiet=False,
597610
):
598611
minimum_duration = int(np.around(stft_db.shape[1] / (duration * 1e3) * min_duration_ms))
599612

600613
stride_ = 2
601614
window = int(window)
602-
buffer = int(round(window / stride_ / 2))
615+
buffer = int(round(window / stride_ / 2)) + extra_buffer_pix
603616

604617
ranges_ = []
605618
for index, (start, stop) in tqdm.tqdm(list(enumerate(ranges)), disable=quiet):
@@ -1409,6 +1422,8 @@ def compute_wrapper(
14091422
bitdepth=16,
14101423
mask_secondary_effects=False,
14111424
plot_uncompressed_amplitude=False,
1425+
include_original_sr=False,
1426+
time_buffer_ms=1.0,
14121427
debug=False,
14131428
**kwargs,
14141429
):
@@ -1472,7 +1487,7 @@ def compute_wrapper(
14721487
warnings.simplefilter('ignore', category=DeprecationWarning)
14731488
# ignore warning due to aifc deprecation
14741489
stft_db, waveplot, sr, bands, duration, freq_offset, time_vec, orig_sr, max_band_idx = (
1475-
load_stft(wav_filepath, fast_mode=fast_mode)
1490+
load_stft(wav_filepath, fast_mode=fast_mode, use_original_sr=False)
14761491
)
14771492

14781493
# Apply a dynamic range to a fixed dB range
@@ -1593,9 +1608,18 @@ def compute_wrapper(
15931608
else:
15941609

15951610
# Tighten the ranges by looking for substantial right-side skew (use stride for a smaller sampling window)
1611+
extra_buffer_pix = int(max(0.0, (time_buffer_ms - 1.0) / x_step_ms))
15961612
ranges = tighten_ranges(
1597-
stft_db, ranges, stride, duration, output_path=debug_path, quiet=quiet
1613+
stft_db,
1614+
ranges,
1615+
stride,
1616+
duration,
1617+
output_path=debug_path,
1618+
extra_buffer_pix=extra_buffer_pix,
1619+
quiet=quiet,
15981620
)
1621+
# Merge all range segments into contiguous range blocks
1622+
ranges = merge_ranges(ranges, stft_db.shape[1])
15991623

16001624
# Extract chirp metrics and metadata
16011625
segments = {
@@ -1731,7 +1755,7 @@ def compute_wrapper(
17311755
metadata.update(slopes)
17321756

17331757
# Trim segment around the bat call with a small buffer
1734-
buffer_ms = 1.0
1758+
buffer_ms = time_buffer_ms
17351759
buffer_pix = int(round(buffer_ms / x_step_ms))
17361760
trim_begin = max(0, min(segment.shape[1], call_begin[1] - buffer_pix))
17371761
trim_end = max(0, min(segment.shape[1], call_end[1] + buffer_pix))
@@ -1839,6 +1863,81 @@ def compute_wrapper(
18391863
[cv2.IMWRITE_TIFF_COMPRESSION, 1],
18401864
)
18411865

1866+
# If desired, also generate uncompressed and compressed spectrograms
1867+
# without reducing the sample rate. These should have similar step
1868+
# size in time and frequency to the resampled spectrogram (asserted below)
1869+
if include_original_sr:
1870+
with warnings.catch_warnings():
1871+
warnings.simplefilter('ignore', category=DeprecationWarning)
1872+
# ignore warning due to aifc deprecation
1873+
(
1874+
stft_db_origsr,
1875+
_,
1876+
_,
1877+
bands_origsr,
1878+
duration_origsr,
1879+
_,
1880+
time_vec_origsr,
1881+
orig_sr,
1882+
max_band_idx_origsr,
1883+
) = load_stft(wav_filepath, fast_mode=fast_mode, use_original_sr=True)
1884+
# Apply a dynamic range to a fixed dB range
1885+
stft_db_origsr = gain_stft(stft_db_origsr, max_band_idx=max_band_idx_origsr)
1886+
1887+
# Bin the floating point data to X-bit integers (X=8 or X=16)
1888+
stft_db_origsr = normalize_stft(stft_db_origsr, None, dtype)
1889+
1890+
# Vertically flip the spectrogram, lowest frequencies on the bottom
1891+
# Convert to a C-contiguous (row-major) array for OpenCV
1892+
stft_db_origsr = np.ascontiguousarray(stft_db_origsr[::-1, :])
1893+
bands_origsr = bands_origsr[::-1]
1894+
y_step_freq_origsr = float(bands_origsr[0] - bands_origsr[1])
1895+
x_step_ms_origsr = float(1e3 * (time_vec_origsr[1] - time_vec_origsr[0]))
1896+
bands_origsr = np.around(bands_origsr).astype(np.int32).tolist()
1897+
1898+
# Allow up to 5% change in step sizes or frequency bands when comparing
1899+
# to band-limited spectrogram.
1900+
tol = 5e-2
1901+
assert (
1902+
np.abs(x_step_ms - x_step_ms_origsr) / x_step_ms <= tol
1903+
), 'time step changed unexpectedly much when using original sample rate'
1904+
assert (
1905+
np.abs(y_step_freq - y_step_freq_origsr) / y_step_freq <= tol
1906+
), 'frequency step changed unexpectedly much when using original sample rate'
1907+
if orig_sr >= sr:
1908+
assert all(
1909+
[np.abs(x - y) / x <= tol for x, y in zip(bands, bands_origsr[-len(bands) :])]
1910+
), 'lower frequency bands changed unexpectedly much when using original sample rate'
1911+
else:
1912+
assert all(
1913+
[
1914+
np.abs(x - y) / x <= tol
1915+
for x, y in zip(bands[-len(bands_origsr) :], bands_origsr)
1916+
]
1917+
), 'lower frequency bands changed unexpectedly much when using original sample rate'
1918+
1919+
# Create compressed spectrogram using segment start and stop times
1920+
segments_origsr = []
1921+
for segment_meta in metas:
1922+
start = max(0, int(np.round(segment_meta['segment start.ms'] / x_step_ms_origsr)))
1923+
end = min(
1924+
stft_db_origsr.shape[1],
1925+
int(np.round(segment_meta['segment end.ms'] / x_step_ms_origsr)),
1926+
)
1927+
segments_origsr.append(stft_db_origsr[:, start:end])
1928+
segments['stft_db_origsr'] = np.concatenate(segments_origsr, axis=1)
1929+
1930+
# Save some metadata
1931+
meta_origsr = {
1932+
'sr.hz': int(orig_sr),
1933+
'duration.ms': round(duration_origsr * 1e3, 3),
1934+
'frequencies': {
1935+
'min.hz': int(FREQ_MIN),
1936+
'max.hz': int(max(bands_origsr)),
1937+
'pixels.hz': bands_origsr,
1938+
},
1939+
}
1940+
18421941
output_paths = []
18431942
compressed_paths = []
18441943
mask_paths = []
@@ -1849,6 +1948,10 @@ def compute_wrapper(
18491948
datas = [
18501949
(output_paths, 'jpg', stft_db),
18511950
]
1951+
if not fast_mode and include_original_sr:
1952+
datas += [
1953+
(output_paths, 'origsr.jpg', stft_db_origsr),
1954+
]
18521955
if plot_uncompressed_amplitude:
18531956
datas += [
18541957
(waveplot_plots, 'waveplot.jpg', waveplot),
@@ -1857,6 +1960,10 @@ def compute_wrapper(
18571960
datas += [
18581961
(compressed_paths, 'compressed.jpg', segments['stft_db']),
18591962
]
1963+
if 'stft_db_origsr' in segments:
1964+
datas += [
1965+
(compressed_paths, 'compressed.origsr.jpg', segments['stft_db_origsr']),
1966+
]
18601967
if 'waveplot' in segments:
18611968
datas += [
18621969
(waveplot_compressed_paths, 'compressed.waveplot.jpg', segments['waveplot']),
@@ -1870,6 +1977,59 @@ def compute_wrapper(
18701977
(masked_paths, 'masked.jpg', masked),
18711978
]
18721979

1980+
# Interpolate waveplots, mask, and masked images to approximately match the original sample rate images
1981+
if include_original_sr:
1982+
if plot_uncompressed_amplitude:
1983+
waveplot_interp = cv2.resize(
1984+
waveplot,
1985+
(stft_db_origsr.shape[1], waveplot.shape[0]),
1986+
interpolation=cv2.INTER_LINEAR,
1987+
)
1988+
datas += [
1989+
(waveplot_plots, 'waveplot.origsr.jpg', waveplot_interp),
1990+
]
1991+
if 'waveplot' in segments:
1992+
waveplot_compressed_interp = cv2.resize(
1993+
segments['waveplot'],
1994+
(segments['stft_db_origsr'].shape[1], segments['waveplot'].shape[0]),
1995+
interpolation=cv2.INTER_LINEAR,
1996+
)
1997+
datas += [
1998+
(
1999+
waveplot_compressed_paths,
2000+
'compressed.waveplot.origsr.jpg',
2001+
waveplot_compressed_interp,
2002+
),
2003+
]
2004+
if 'costs' in segments and 'stft_db' in segments:
2005+
mask_interp = cv2.resize(
2006+
segments['costs'],
2007+
(segments['stft_db_origsr'].shape[1], segments['costs'].shape[0]),
2008+
interpolation=cv2.INTER_LINEAR,
2009+
)
2010+
masked_interp = cv2.resize(
2011+
masked,
2012+
(segments['stft_db_origsr'].shape[1], masked.shape[0]),
2013+
interpolation=cv2.INTER_LINEAR,
2014+
)
2015+
if orig_sr >= sr:
2016+
# Pad mask and masked to account for extra higher frequencies
2017+
mask_interp = np.pad(
2018+
mask_interp, ((stft_db_origsr.shape[0] - mask_interp.shape[0], 0), (0, 0))
2019+
)
2020+
masked_interp = np.pad(
2021+
masked_interp, ((stft_db_origsr.shape[0] - masked_interp.shape[0], 0), (0, 0))
2022+
)
2023+
else:
2024+
# Remove the higher-frequency rows from mask and masked that aren't present at the original sample rate
2025+
mask_interp = mask_interp[mask_interp.shape[0] - stft_db_origsr.shape[0] :]
2026+
masked_interp = masked_interp[masked_interp.shape[0] - stft_db_origsr.shape[0] :]
2027+
pass
2028+
datas += [
2029+
(mask_paths, 'mask.origsr.jpg', mask_interp),
2030+
(masked_paths, 'masked.origsr.jpg', masked_interp),
2031+
]
2032+
18732033
for accumulator, tag, data in datas:
18742034
if data.dtype != np.uint8:
18752035
data_ = data.astype(np.float32)
@@ -1926,9 +2086,22 @@ def compute_wrapper(
19262086
'width.px': segments['stft_db'].shape[1],
19272087
'height.px': segments['stft_db'].shape[0],
19282088
}
2089+
if 'stft_db_origsr' in segments:
2090+
metadata['size']['compressed_origsr'] = {
2091+
'width.px': segments['stft_db_origsr'].shape[1],
2092+
'height.px': segments['stft_db_origsr'].shape[0],
2093+
}
2094+
metadata['size']['uncompressed_origsr'] = {
2095+
'width.px': stft_db_origsr.shape[1],
2096+
'height.px': stft_db_origsr.shape[0],
2097+
}
2098+
metadata['metadata_origsr'] = meta_origsr
19292099
if 'costs' in segments and 'stft_db' in segments:
19302100
metadata['size']['mask'] = metadata['size']['compressed']
19312101
metadata['size']['masked'] = metadata['size']['compressed']
2102+
if include_original_sr:
2103+
metadata['size']['mask_origsr'] = metadata['size']['compressed_origsr']
2104+
metadata['size']['masked_origsr'] = metadata['size']['compressed_origsr']
19322105

19332106
metadata_path = f'{out_file_stem}.metadata.json'
19342107
with open(metadata_path, 'w') as metafile:

0 commit comments

Comments
 (0)