17 changes: 15 additions & 2 deletions cctorch/data.py
@@ -596,7 +596,9 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
meta = obspy.read(fs, format="MSEED")
stream += meta
# stream += obspy.read(tmp)
stream = stream.merge(fill_value="latest")

stream_mask = stream.copy().merge(fill_value=None)
stream = stream.merge(fill_value=0)

## FIXME: HARDCODE for California
if tmp.startswith("s3://ncedc-pds"):
@@ -605,12 +607,14 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
begin_time = obspy.UTCDateTime(year=year, julday=jday)
end_time = begin_time + 86400 ## 1 day
stream = stream.trim(begin_time, end_time, pad=True, fill_value=0, nearest_sample=True)
stream_mask = stream_mask.trim(begin_time, end_time, pad=True, fill_value=None, nearest_sample=True)
elif tmp.startswith("s3://scedc-pds"):
year_jday = tmp.split("/")[-1].rstrip(".ms")[-7:]
year, jday = int(year_jday[:4]), int(year_jday[4:])
begin_time = obspy.UTCDateTime(year=year, julday=jday)
end_time = begin_time + 86400 ## 1 day
stream = stream.trim(begin_time, end_time, pad=True, fill_value=0, nearest_sample=True)
stream_mask = stream_mask.trim(begin_time, end_time, pad=True, fill_value=None, nearest_sample=True)
except Exception as e:
print(f"Error reading {fname}:\n{e}")
return None
@@ -655,6 +659,7 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
begin_time = min([st.stats.starttime for st in stream])
end_time = max([st.stats.endtime for st in stream])
stream = stream.trim(begin_time, end_time, pad=True, fill_value=0)
stream_mask = stream_mask.trim(begin_time, end_time, pad=True, fill_value=None)

comp = ["3", "2", "1", "E", "N", "Z"]
comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
@@ -675,6 +680,7 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
nt = int(24 * 60 * 60 * sampling_rate) + 1

data = np.zeros([3, nx, nt], dtype=np.float32)
mask = np.zeros([3, nx, nt], dtype=np.int8)
for i, sta in enumerate(station_keys):
for c in station_ids[sta]:
j = comp2idx[c]
@@ -684,19 +690,26 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
continue

trace = stream.select(id=sta + c)[0]
trace_mask = stream_mask.select(id=sta + c)[0]
try:
mask_array = trace_mask.data.mask
mask_array = mask_array.astype(int)
except AttributeError:
mask_array = np.zeros(len(trace_mask.data), dtype=int)

## acceleration to velocity
if sta[-1] == "N":
trace = trace.integrate().filter("highpass", freq=1.0)

tmp = trace.data.astype("float32")
data[j, i, : len(tmp)] = tmp[:nt]
mask[j, i, : len(mask_array)] = mask_array[:nt]

# return data, {
# "begin_time": begin_time.datetime, # .strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
# "end_time": end_time.datetime, # .strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
# }
return data, {
return data, {"mask": mask,
"begin_time": np.datetime64(begin_time.datetime),
"end_time": np.datetime64(end_time.datetime),
}
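The gap mask added above relies on `Stream.merge(fill_value=None)` keeping gaps as masked samples (a numpy masked array), while the data itself is merged with `fill_value=0`. A minimal sketch of that behavior, with a made-up station and a roughly 1 s gap (obspy required):

```python
import numpy as np
import obspy

# Two segments of the same hypothetical channel with a ~1 s gap between them.
header = {"network": "XX", "station": "TEST", "channel": "HHZ", "sampling_rate": 100.0}
tr1 = obspy.Trace(data=np.random.randn(500).astype(np.float32), header=header)
tr2 = obspy.Trace(data=np.random.randn(500).astype(np.float32), header=header)
tr2.stats.starttime = tr1.stats.endtime + 1.0  # leave a gap after the first segment

# fill_value=None keeps the gap as a masked region instead of filling it,
# while fill_value=0 (used for the data itself) writes zeros into the gap.
stream_mask = obspy.Stream([tr1, tr2]).merge(fill_value=None)
gap = np.ma.getmaskarray(stream_mask[0].data)  # all False when nothing is masked
print(gap.astype(np.int8).sum(), "masked samples out of", len(gap))
```

In `read_mseed` the same mask is cast to `int8` per channel, so `1` marks a gap sample and `0` marks real data.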
73 changes: 55 additions & 18 deletions cctorch/model.py
@@ -6,6 +6,7 @@
import torch.nn.functional as F
from tqdm import tqdm

from .utils import partial_hann_taper, custom_demeaned_stft, cosine_taper_4freq

class CCModel(nn.Module):
def __init__(
@@ -40,7 +41,8 @@ def __init__(
# AN
self.nlag = config.nlag
self.nfft = self.nlag * 2
self.window = torch.hann_window(self.nfft, periodic=False).to(self.device)
# self.window = torch.hann_window(self.nfft, periodic=False).to(self.device)
self.window = partial_hann_taper(self.nfft, 0.04, self.device)
self.spectral_whitening = config.spectral_whitening

def forward(self, x):
@@ -127,37 +129,72 @@ def forward(self, x):
xcor = torch.mean(xcor, dim=(-3), keepdim=True)

elif self.domain == "stft":
overlap_ratio = 0.5
hop_length = int(self.nlag * ((1-overlap_ratio)/0.5))

mask1 = x1['info']['mask']
mask2 = x2['info']['mask']
mask1 = torch.from_numpy(np.stack(mask1, axis=0)).float()
mask2 = torch.from_numpy(np.stack(mask2, axis=0)).float()

pooled_mask1 = F.max_pool2d(
mask1,
kernel_size=(1, self.nlag * 2 + 5),
stride=(1, hop_length),
)
pooled_mask2 = F.max_pool2d(
mask2,
kernel_size=(1, self.nlag * 2 + 5),
stride=(1, hop_length),
)
mask_reshaped1 = pooled_mask1.view(pooled_mask1.shape[0]*pooled_mask1.shape[1], 1, pooled_mask1.shape[-1])
mask_reshaped2 = pooled_mask2.view(pooled_mask2.shape[0]*pooled_mask2.shape[1], 1, pooled_mask2.shape[-1])

mask_reshaped1 = 1 - mask_reshaped1
mask_reshaped2 = 1 - mask_reshaped2

nlag = self.nlag
nb1, nc1, nx1, nt1 = data1.shape
# nb2, nc2, nx2, nt2 = data2.shape
data1 = data1.view(nb1 * nc1 * nx1, nt1)
# data2 = data2.view(nb2 * nc2 * nx2, nt2)
data2 = data2.view(nb1 * nc1 * nx1, nt1)
if not self.pre_fft:
data1 = torch.stft(
data1,
n_fft=self.nlag * 2,
hop_length=self.nlag,
window=self.window,
center=True,
return_complex=True,
)
data2 = torch.stft(
data2,
n_fft=self.nlag * 2,
hop_length=self.nlag,
window=self.window,
center=True,
return_complex=True,
)

data1 = custom_demeaned_stft(data1, nlag, hop_length, self.window)
data2 = custom_demeaned_stft(data2, nlag, hop_length, self.window)


# data1 = torch.stft(
# data1,
# n_fft=self.nlag * 2,
# hop_length=self.nlag,
# window=self.window,
# center=True,
# return_complex=True,
# )
# data2 = torch.stft(
# data2,
# n_fft=self.nlag * 2,
# hop_length=self.nlag,
# window=self.window,
# center=True,
# return_complex=True,
# )
if self.spectral_whitening:
# freqs = np.fft.fftfreq(self.nlag*2, d=self.dt)
# data1 = data1 / torch.clip(torch.abs(data1), min=1e-7) #float32 eps
# data2 = data2 / torch.clip(torch.abs(data2), min=1e-7)
data1 = torch.exp(1j * data1.angle())
data2 = torch.exp(1j * data2.angle())

xcor = torch.fft.irfft(torch.sum(data1 * torch.conj(data2), dim=-1), dim=-1)
f_taper_asym = cosine_taper_4freq(data1.shape[1], low=0.01, high=9.8)
data1 = data1 * f_taper_asym
data2 = data2 * f_taper_asym

# xcor = torch.fft.irfft(torch.sum(data1 * torch.conj(data2), dim=-1), dim=-1)
xcor = torch.fft.irfft(torch.sum(data1 * torch.conj(data2) * mask_reshaped1 * mask_reshaped2, dim=-1), n=self.nlag * 2 + 1, dim=-1)
xcor = xcor / data1.size(1)
xcor = torch.roll(xcor, self.nlag, dims=-1)
xcor = xcor.view(nb1, nc1, nx1, -1)
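On the pooling parameters above: `kernel_size=(1, self.nlag * 2 + 5)` with `stride=(1, hop_length)` makes each max-pooled value cover exactly the samples of one STFT frame from `custom_demeaned_stft`, so `1 - pooled` acts as a per-frame validity weight. A toy sketch with assumed sizes (batch of 2, 3 components, one trace per item):

```python
import torch
import torch.nn.functional as F

# Toy, assumed shapes: batch of 2, 3 components, 1 trace per item, 6000 samples.
nb, nc, nx, nt = 2, 3, 1, 6000
nlag, overlap_ratio = 200, 0.5
hop_length = int(nlag * ((1 - overlap_ratio) / 0.5))
n_fft = nlag * 2 + 5

mask = torch.zeros(nb, nc, nx, nt)
mask[..., 1000:1100] = 1.0  # pretend samples 1000-1099 are a data gap

# Every pooled value covers the same sample range as one STFT frame,
# so any frame that touches a gap sample is pooled to 1.
pooled = F.max_pool2d(mask, kernel_size=(1, n_fft), stride=(1, hop_length))
frame_valid = 1 - pooled.view(pooled.shape[0] * pooled.shape[1], 1, pooled.shape[-1])  # assumes nx == 1

# unfold(size=n_fft, step=hop_length), as in custom_demeaned_stft, yields the same frame count.
frames = torch.randn(nb * nc * nx, nt).unfold(-1, n_fft, hop_length)
assert frames.shape[1] == frame_valid.shape[-1]
print(frame_valid.shape, frames.shape)  # torch.Size([6, 1, 28]) torch.Size([6, 28, 405])
```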

100 changes: 100 additions & 0 deletions cctorch/utils.py
@@ -489,6 +489,106 @@ def write_h5(fn, dataset_name, data, attrs_dict):
fid[dataset_name].attrs.modify(key, val)


def partial_hann_taper(length, taper_fraction=0.04, device="cpu"):
n_taper = int(length * taper_fraction)
if n_taper == 0:
return torch.ones(length, device=device)

# Hann window for edges
x = torch.linspace(0, torch.pi / 2, n_taper, device=device)
taper_edge = torch.sin(x)**2 # sin² taper

taper_start = taper_edge
taper_end = taper_edge.flip(0)
taper_tail = torch.zeros(5, device=device)


# Build full window: start taper + flat middle + end taper + 5-sample zero tail (length + 5 total, matching n_fft = 2 * nlag + 5)
ones_middle = torch.ones(length - 2 * n_taper, device=device)
window = torch.cat([taper_start, ones_middle, taper_end, taper_tail], dim=0)
return window

def custom_demeaned_stft(data1, nlag, hop_length, window):
"""
Custom STFT whose framing matches torch.stft(..., center=False) (no padding), with per-window demeaning applied before the window function

Args:
data1: (B, T) time-domain signal
nlag: for computing n_fft = 2 * nlag + 5
hop_length: step size between windows
window: (n_fft,) window function (e.g., Hann)

Returns:
Complex STFT of shape (B, freq_bins, time_frames), matching torch.stft
"""
n_fft = 2 * nlag + 5
B, T = data1.shape

# Number of complete frames without padding (kept for reference; unfold below yields this many)
num_frames = (T - n_fft) // hop_length + 1

# Use unfold to extract frames
frames = data1.unfold(dimension=-1, size=n_fft, step=hop_length) # (B, num_frames, n_fft)

# Demean each frame
frames = frames - frames.mean(dim=-1, keepdim=True)

# Apply window
window = window.to(data1.device)
frames = frames * window.view(1, 1, -1)

# Apply FFT
stft_result = torch.fft.rfft(frames, dim=-1) # (B, num_frames, freq_bins)

# Transpose to match torch.stft output: (B, freq_bins, time_frames)
stft_result = stft_result.transpose(-1, -2)

return stft_result # shape: (B, freq_bins, time_frames)
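For reference, a minimal shape check of `custom_demeaned_stft` together with `partial_hann_taper`, using assumed toy sizes and random data:

```python
import torch

# Toy sizes (assumed): 4 traces of 6000 samples, nlag = 200, hop = 200.
B, T = 4, 6000
nlag, hop_length = 200, 200
n_fft = 2 * nlag + 5

x = torch.randn(B, T)
window = partial_hann_taper(2 * nlag, 0.04)  # 405 samples once the 5-sample zero tail is appended
spec = custom_demeaned_stft(x, nlag, hop_length, window)

expected_frames = (T - n_fft) // hop_length + 1
assert spec.shape == (B, n_fft // 2 + 1, expected_frames)
print(spec.shape)  # torch.Size([4, 203, 28])
```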

def cosine_taper_4freq(n_freqs, low, high, sample_rate=20):
"""
Create a 1D cosine taper with flat region between left_end and right_start,
and cosine transitions on both sides.

Parameters:
- n_freqs: total number of frequency bins
- left_start, left_end, right_start, right_end: index positions in frequency domain

Returns:
- taper: tensor of shape [n_freqs]
"""
delta_f = sample_rate / ((n_freqs - 1)*2 + 1)
low_idx = int(np.ceil(low / delta_f))
high_idx = int(np.floor(high / delta_f))
low_left = low_idx - 100
if low_left < 0:
low_left = 0
high_right = high_idx + 100
if high_right > n_freqs - 1:
high_right = n_freqs - 1
# print(f"Doing the classic Brutal Whiten {n_freqs} {low_left} {low_idx} {high_idx} {high_right}")
left_start = low_left
left_end = low_idx
right_start = high_idx
right_end = high_right

taper = np.zeros(n_freqs)

# Left cosine ramp
for i in range(left_start, left_end):
frac = (i - left_start) / (left_end - left_start)
taper[i] = 0.5 * (1 - np.cos(np.pi * frac))

# Flat part
taper[left_end:right_start] = 1.0

# Right cosine ramp
for i in range(right_start, right_end):
frac = (i - right_start) / (right_end - right_start)
taper[i] = 0.5 * (1 + np.cos(np.pi * frac))
cos_taper = torch.tensor(taper, dtype=torch.float32)
return cos_taper[None, :, None]
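And a short sanity check of `cosine_taper_4freq` with assumed values (20 Hz sampling and a one-sided bin count large enough for the 100-bin ramps, e.g. a hypothetical nlag = 6000):

```python
import torch

# Assumed sizes: nlag = 6000 lag samples at 20 Hz -> 6003 one-sided frequency bins.
n_freqs, sample_rate = 6003, 20
taper = cosine_taper_4freq(n_freqs, low=0.01, high=9.8, sample_rate=sample_rate)
assert taper.shape == (1, n_freqs, 1)

delta_f = sample_rate / ((n_freqs - 1) * 2 + 1)
freqs = torch.arange(n_freqs) * delta_f
passband = (freqs >= 0.01) & (freqs <= 9.8)
# The flat pass band stays at 1; the cosine ramps outside it fall away toward 0.
print(taper[0, passband, 0].min().item(), taper[0, ~passband, 0].max().item())
```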

# # %%
# @dataclass
# class Config:
Binary file added examples/california/.mseeds1_2005_123.txt.swp
Binary file not shown.