Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions cctorch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,11 +551,11 @@ def read_data(file_name, data_path, format="h5", mode="CC", config={}):
if format == "h5":
data, info = read_das_continuous_data_h5(data_path / file_name, dataset_keys=[])
elif format == "mseed":
data, info = read_mseed(file_name, config=config)
data, info = read_mseed(file_name, config=config, sampling_rate=config.fs)

elif mode == "TM":
if format == "mseed":
data, info = read_mseed(file_name, config=config)
data, info = read_mseed(file_name, config=config, sampling_rate=config.fs)
# data, info = read_mseed_3c(file_name, config=config)
else:
raise ValueError(f"Unknown mode: {mode}")
Expand All @@ -576,6 +576,7 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
stream += meta
# stream += obspy.read(tmp)
stream = stream.merge(fill_value="latest")
stream.detrend("demean")

## FIXME: HARDCODE for California
if tmp.startswith("s3://ncedc-pds"):
Expand Down Expand Up @@ -603,10 +604,10 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
if trace.stats.sampling_rate != sampling_rate:
logging.warning(f"Resampling {trace.id} from {trace.stats.sampling_rate} to {sampling_rate} Hz")
try:
trace = trace.interpolate(sampling_rate, method="linear")
if tmp.startswith("s3://ncedc-pds"):
trace = trace.trim(begin_time, end_time, pad=True, fill_value=0, nearest_sample=True)
elif tmp.startswith("s3://scedc-pds"):
trace.filter("lowpass", freq=0.45 * sampling_rate, zerophase=True, corners=8)
trace.interpolate(method="lanczos", sampling_rate=sampling_rate, a=1.0)
# trace = trace.interpolate(sampling_rate, method="linear")
if tmp.startswith(("s3://ncedc-pds", "s3://scedc-pds")):
trace = trace.trim(begin_time, end_time, pad=True, fill_value=0, nearest_sample=True)
except Exception as e:
print(f"Error resampling {trace.id}:\n{e}")
Expand Down Expand Up @@ -647,10 +648,10 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
nx = len(station_ids)
nt = max([len(tr.data) for tr in stream])

## FIXME: HARDCODE for California
if tmp.startswith("s3://ncedc-pds") or tmp.startswith("s3://scedc-pds"):
nt = 8640001

# ## FIXME: HARDCODE for California
# if tmp.startswith("s3://ncedc-pds") or tmp.startswith("s3://scedc-pds"):
# nt = 8640001
data = np.zeros([3, nx, nt], dtype=np.float32)
for i, sta in enumerate(station_keys):
for c in station_ids[sta]:
Expand Down
24 changes: 22 additions & 2 deletions cctorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,28 @@ def __init__(
# AN
self.nlag = config.nlag
self.nfft = self.nlag * 2
self.window = torch.hann_window(self.nfft, periodic=False).to(self.device)
# self.window = torch.hann_window(self.nfft, periodic=False).to(self.device)
self.spectral_whitening = config.spectral_whitening

def partial_hann_taper(self, length, taper_fraction=0.04, device="cpu"):
# print('Chris flag taper', length, taper_fraction)
n_taper = int(length * taper_fraction)
if n_taper == 0:
return torch.ones(length, device=device)

# Hann window for edges
x = torch.linspace(0, torch.pi / 2, n_taper, device=device)
taper_edge = torch.sin(x)**2 # sin² taper

taper_start = taper_edge
taper_end = taper_edge.flip(0)


# Build full window: start + flat + end
ones_middle = torch.ones(length - 2 * n_taper, device=device)
window = torch.cat([taper_start, ones_middle, taper_end], dim=0)
return window

def forward(self, x):
"""Perform cross-correlation on input data
Args:
Expand All @@ -54,7 +73,7 @@ def forward(self, x):
- data (torch.Tensor): data2 with shape (batch, nsta/nch, nt)
- info (dict): information information of data2
"""

self.window = self.partial_hann_taper(self.nfft, 0.04, device=self.device)
x1, x2 = x
if self.to_device:
data1 = x1["data"].to(self.device)
Expand Down Expand Up @@ -158,6 +177,7 @@ def forward(self, x):
data2 = torch.exp(1j * data2.angle())

xcor = torch.fft.irfft(torch.sum(data1 * torch.conj(data2), dim=-1), dim=-1)
xcor = xcor / data1.size(1)
xcor = torch.roll(xcor, self.nlag, dims=-1)
xcor = xcor.view(nb1, nc1, nx1, -1)

Expand Down
4 changes: 3 additions & 1 deletion run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import json
import logging
import os
import threading
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from dataclasses import dataclass
Expand Down