Skip to content

Commit 5879ae1

Browse files
committed
scripts: add logging to plot_signal.py and pod5_to_df.py
Closes #10
1 parent 8b97e01 commit 5879ae1

2 files changed

Lines changed: 200 additions & 122 deletions

File tree

bin/plot_signal.py

Lines changed: 92 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66
import matplotlib.pyplot as plt
77
import argparse
88
import os
9+
import logging
10+
import sys
11+
12+
# Configure logging to stderr
13+
logging.basicConfig(
14+
level=logging.INFO,
15+
format='%(asctime)s - %(levelname)s - %(message)s',
16+
stream=sys.stderr
17+
)
918

1019
def main():
1120
parser = argparse.ArgumentParser()
@@ -30,66 +39,97 @@ def main():
3039
)
3140
args = parser.parse_args()
3241

33-
# Load the data generated by pod5_to_df.py
34-
if not os.path.exists(args.csv):
35-
print(f"Error: CSV file {args.csv} not found.")
36-
return
42+
try:
43+
logging.info(f"Starting visualization for CSV: {args.csv}")
44+
45+
# Load the data generated by pod5_to_df.py
46+
if not os.path.exists(args.csv):
47+
logging.error(f"CSV file {args.csv} not found.")
48+
sys.exit(1)
3749

38-
raw_signal_df = pd.read_csv(args.csv)
50+
try:
51+
raw_signal_df = pd.read_csv(args.csv)
52+
logging.info(f"Successfully loaded CSV file with {len(raw_signal_df)} data points")
53+
except Exception as e:
54+
logging.error(f"Failed to load CSV file {args.csv}: {e}")
55+
raise
3956

40-
# Plotting configuration
41-
sns.set(font_scale=1)
42-
sns.set_style("white")
43-
fig, ax = plt.subplots(1, 1, figsize=(args.figwidth, args.figheight))
44-
x_feature, y_feature = 'time', 'signal'
57+
# Plotting configuration
58+
try:
59+
sns.set(font_scale=1)
60+
sns.set_style("white")
61+
fig, ax = plt.subplots(1, 1, figsize=(args.figwidth, args.figheight))
62+
x_feature, y_feature = 'time', 'signal'
63+
logging.info("Initialized plotting figure")
64+
except Exception as e:
65+
logging.error(f"Failed to initialize plot: {e}")
66+
raise
4567

46-
47-
# Check if 'ann' column exists (requires Dorado move tags)
48-
if 'ann' in raw_signal_df.columns:
49-
# 1. Unannotated part (ann == -2)
50-
df_neg2 = raw_signal_df[raw_signal_df['ann'] == -2]
51-
if not df_neg2.empty:
52-
sns.scatterplot(data=df_neg2, x=x_feature, y=y_feature, color='blue',
53-
label='unannotated part', s=50, zorder=4, ax=ax)
68+
69+
# Check if 'ann' column exists (requires Dorado move tags)
70+
try:
71+
if 'ann' in raw_signal_df.columns:
72+
# 1. Unannotated part (ann == -2)
73+
df_neg2 = raw_signal_df[raw_signal_df['ann'] == -2]
74+
if not df_neg2.empty:
75+
sns.scatterplot(data=df_neg2, x=x_feature, y=y_feature, color='blue',
76+
label='unannotated part', s=50, zorder=4, ax=ax)
5477

55-
# 2. Trimmed primer/adapter (ann == -1)
56-
df_neg1 = raw_signal_df[raw_signal_df['ann'] == -1]
57-
if not df_neg1.empty:
58-
sns.lineplot(data=df_neg1, x=x_feature, y=y_feature, color='green',
59-
label='trimmed primer and adapter', zorder=2, ax=ax)
78+
# 2. Trimmed primer/adapter (ann == -1)
79+
df_neg1 = raw_signal_df[raw_signal_df['ann'] == -1]
80+
if not df_neg1.empty:
81+
sns.lineplot(data=df_neg1, x=x_feature, y=y_feature, color='green',
82+
label='trimmed primer and adapter', zorder=2, ax=ax)
6083

61-
# 3. Basecalled region (ann is 0 or 1)
62-
df_base = raw_signal_df[raw_signal_df['ann'].isin([0, 1])]
63-
if not df_base.empty:
64-
sns.lineplot(data=df_base, x=x_feature, y=y_feature, color='orange',
65-
label='basecalled region', zorder=3, ax=ax)
84+
# 3. Basecalled region (ann is 0 or 1)
85+
df_base = raw_signal_df[raw_signal_df['ann'].isin([0, 1])]
86+
if not df_base.empty:
87+
sns.lineplot(data=df_base, x=x_feature, y=y_feature, color='orange',
88+
label='basecalled region', zorder=3, ax=ax)
6689

67-
# 5. Samples that emit bases (ann == 1) - Red Circles
68-
df_emit = raw_signal_df[raw_signal_df['ann'] == 1]
69-
if not df_emit.empty:
70-
sns.scatterplot(data=df_emit, x=x_feature, y=y_feature, color='red',
71-
label='samples that emit bases', s=50, fc="none", ec='red', zorder=6, ax=ax)
72-
else:
73-
# Fallback: Plot raw signal if 'ann' is missing (Move tags were likely missing in BAM)
74-
print(f"Warning: 'ann' column missing in {args.csv}. Plotting raw signal only.")
75-
sns.lineplot(data=raw_signal_df, x=x_feature, y=y_feature, color='grey',
76-
alpha=0.5, label='raw signal (unannotated)', ax=ax)
90+
# 5. Samples that emit bases (ann == 1) - Red Circles
91+
df_emit = raw_signal_df[raw_signal_df['ann'] == 1]
92+
if not df_emit.empty:
93+
sns.scatterplot(data=df_emit, x=x_feature, y=y_feature, color='red',
94+
label='samples that emit bases', s=50, fc="none", ec='red', zorder=6, ax=ax)
95+
logging.info("Plotted annotated signal regions")
96+
else:
97+
# Fallback: Plot raw signal if 'ann' is missing (Move tags were likely missing in BAM)
98+
logging.warning(f"'ann' column missing in {args.csv}. Plotting raw signal only.")
99+
sns.lineplot(data=raw_signal_df, x=x_feature, y=y_feature, color='grey',
100+
alpha=0.5, label='raw signal (unannotated)', ax=ax)
101+
except Exception as e:
102+
logging.error(f"Error plotting signal regions: {e}")
103+
raise
77104

78-
# 4. Poly-A Tail Region (independent check for 'polyA' column)
79-
if 'polyA' in raw_signal_df.columns:
80-
df_polya = raw_signal_df[raw_signal_df['polyA'] > -1]
81-
if not df_polya.empty:
82-
sns.lineplot(data=df_polya, x=x_feature, y=y_feature, color='magenta',
83-
label='polyA-tail region', zorder=5, linewidth=2, ax=ax)
105+
# 4. Poly-A Tail Region (independent check for 'polyA' column)
106+
try:
107+
if 'polyA' in raw_signal_df.columns:
108+
df_polya = raw_signal_df[raw_signal_df['polyA'] > -1]
109+
if not df_polya.empty:
110+
sns.lineplot(data=df_polya, x=x_feature, y=y_feature, color='magenta',
111+
label='polyA-tail region', zorder=5, linewidth=2, ax=ax)
112+
logging.info("Plotted Poly-A tail region")
113+
except Exception as e:
114+
logging.error(f"Error plotting Poly-A tail region: {e}")
115+
raise
84116

85-
ax.set(title=args.title)
86-
87-
# Legend placement
88-
ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
89-
plt.tight_layout()
90-
plt.savefig(args.output, dpi=300)
91-
plt.close()
92-
print(f"Successfully saved plot to {args.output}")
117+
try:
118+
ax.set(title=args.title)
119+
120+
# Legend placement
121+
ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
122+
plt.tight_layout()
123+
plt.savefig(args.output, dpi=300)
124+
plt.close()
125+
logging.info(f"Successfully saved plot to {args.output}")
126+
except Exception as e:
127+
logging.error(f"Failed to save plot to {args.output}: {e}")
128+
raise
129+
130+
except Exception as e:
131+
logging.error(f"Fatal error in visualization: {e}")
132+
sys.exit(1)
93133

94134
if __name__ == "__main__":
95135
main()

bin/pod5_to_df.py

Lines changed: 108 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@
44
import numpy as np
55
import pysam
66
import argparse
7+
import logging
8+
import sys
9+
10+
# Configure logging to stderr
11+
logging.basicConfig(
12+
level=logging.INFO,
13+
format='%(asctime)s - %(levelname)s - %(message)s',
14+
stream=sys.stderr
15+
)
716

817
def main():
918
parser = argparse.ArgumentParser()
@@ -12,85 +21,114 @@ def main():
1221
parser.add_argument('--sample_id', required=True)
1322
args = parser.parse_args()
1423

15-
# 1. Load BAM data including Poly-A estimation tags
16-
bam_data = {}
17-
with pysam.AlignmentFile(args.bam, "rb", check_sq=False) as bam:
18-
for read in bam.fetch(until_eof=True):
19-
bam_data[read.query_name] = {
20-
'seq': read.query_sequence,
21-
'mv': read.get_tag("mv") if read.has_tag("mv") else None,
22-
'ts': read.get_tag("ts") if read.has_tag("ts") else 0,
23-
'ns': read.get_tag("ns") if read.has_tag("ns") else 0,
24-
'pt': read.get_tag("pt") if read.has_tag("pt") else 0, # Poly-A length
25-
'pa': read.get_tag("pa") if read.has_tag("pa") else None # Signal anchors
26-
}
24+
try:
25+
logging.info(f"Starting processing for sample: {args.sample_id}")
26+
logging.info(f"Loading BAM file: {args.bam}")
27+
28+
# 1. Load BAM data including Poly-A estimation tags
29+
bam_data = {}
30+
try:
31+
with pysam.AlignmentFile(args.bam, "rb", check_sq=False) as bam:
32+
for read in bam.fetch(until_eof=True):
33+
bam_data[read.query_name] = {
34+
'seq': read.query_sequence,
35+
'mv': read.get_tag("mv") if read.has_tag("mv") else None,
36+
'ts': read.get_tag("ts") if read.has_tag("ts") else 0,
37+
'ns': read.get_tag("ns") if read.has_tag("ns") else 0,
38+
'pt': read.get_tag("pt") if read.has_tag("pt") else 0, # Poly-A length
39+
'pa': read.get_tag("pa") if read.has_tag("pa") else None # Signal anchors
40+
}
41+
logging.info(f"Successfully loaded {len(bam_data)} reads from BAM file")
42+
except Exception as e:
43+
logging.error(f"Failed to load BAM file {args.bam}: {e}")
44+
raise
2745

28-
with p5.Reader(args.pod5) as reader:
29-
for read_record in reader.reads():
30-
read_id = str(read_record.read_id)
31-
if read_id not in bam_data:
32-
continue
46+
logging.info(f"Processing POD5 file: {args.pod5}")
47+
processed_reads = 0
48+
49+
try:
50+
with p5.Reader(args.pod5) as reader:
51+
for read_record in reader.reads():
52+
read_id = str(read_record.read_id)
53+
if read_id not in bam_data:
54+
continue
3355

34-
# Signal extraction
35-
signal = read_record.signal
36-
sample_rate = read_record.run_info.sample_rate
37-
time = np.arange(len(signal)) / sample_rate
38-
raw_signal_df = pd.DataFrame({'time': time, 'signal': signal})
56+
try:
57+
# Signal extraction
58+
signal = read_record.signal
59+
sample_rate = read_record.run_info.sample_rate
60+
time = np.arange(len(signal)) / sample_rate
61+
raw_signal_df = pd.DataFrame({'time': time, 'signal': signal})
3962

40-
info = bam_data[read_id]
41-
42-
# Initialize polyA column (default to -1 for non-tail region)
43-
raw_signal_df['polyA'] = -1
63+
info = bam_data[read_id]
64+
65+
# Initialize polyA column (default to -1 for non-tail region)
66+
raw_signal_df['polyA'] = -1
4467

45-
# 2. Map Poly-A Region if tags exist
46-
if info['pa'] is not None:
47-
# pa array indices: 1 = start of polyA, 2 = end of polyA
48-
pa_start = info['pa'][1]
49-
pa_end = info['pa'][2]
50-
51-
# Fill the polyA column with the estimated length (pt) for that region
52-
# This makes it easy to filter/color the tail in plots
53-
raw_signal_df.loc[pa_start:pa_end, 'polyA'] = info['pt']
68+
# 2. Map Poly-A Region if tags exist
69+
if info['pa'] is not None:
70+
# pa array indices: 1 = start of polyA, 2 = end of polyA
71+
pa_start = info['pa'][1]
72+
pa_end = info['pa'][2]
73+
74+
# Fill the polyA column with the estimated length (pt) for that region
75+
# This makes it easy to filter/color the tail in plots
76+
raw_signal_df.loc[pa_start:pa_end, 'polyA'] = info['pt']
5477

55-
# 3. Handle Moves and Base Mapping
56-
if info['mv'] is not None:
57-
stride = info['mv'][0]
58-
moves = info['mv'][1:]
59-
sequence = info['seq']
60-
61-
# Build 'ann' (moves) array
62-
a = info['ts'] * [-1]
63-
for elem in moves:
64-
a += [elem] * stride
65-
a += (len(raw_signal_df) - info['ns']) * [-1]
66-
a = [-2] * max((len(raw_signal_df) - len(a)), 0) + a
67-
68-
# Trim or pad 'a' to match signal length exactly
69-
if len(a) > len(raw_signal_df):
70-
a = a[:len(raw_signal_df)]
71-
else:
72-
a += [-1] * (len(raw_signal_df) - len(a))
73-
74-
raw_signal_df['ann'] = a
78+
# 3. Handle Moves and Base Mapping
79+
if info['mv'] is not None:
80+
stride = info['mv'][0]
81+
moves = info['mv'][1:]
82+
sequence = info['seq']
83+
84+
# Build 'ann' (moves) array
85+
a = info['ts'] * [-1]
86+
for elem in moves:
87+
a += [elem] * stride
88+
a += (len(raw_signal_df) - info['ns']) * [-1]
89+
a = [-2] * max((len(raw_signal_df) - len(a)), 0) + a
90+
91+
# Trim or pad 'a' to match signal length exactly
92+
if len(a) > len(raw_signal_df):
93+
a = a[:len(raw_signal_df)]
94+
else:
95+
a += [-1] * (len(raw_signal_df) - len(a))
96+
97+
raw_signal_df['ann'] = a
7598

76-
# 4. Map Nucleotides to Signal
77-
base_labels = ['N'] * len(raw_signal_df)
78-
seq_idx = 0
79-
seq_len = len(sequence)
99+
# 4. Map Nucleotides to Signal
100+
base_labels = ['N'] * len(raw_signal_df)
101+
seq_idx = 0
102+
seq_len = len(sequence)
80103

81-
for i, val in enumerate(a):
82-
if val == 1 and seq_idx < seq_len:
83-
current_base = sequence[seq_idx]
84-
base_labels[i] = current_base
85-
seq_idx += 1
86-
elif val == 0 and seq_idx > 0:
87-
base_labels[i] = sequence[seq_idx - 1]
104+
for i, val in enumerate(a):
105+
if val == 1 and seq_idx < seq_len:
106+
current_base = sequence[seq_idx]
107+
base_labels[i] = current_base
108+
seq_idx += 1
109+
elif val == 0 and seq_idx > 0:
110+
base_labels[i] = sequence[seq_idx - 1]
88111

89-
raw_signal_df['base'] = base_labels
112+
raw_signal_df['base'] = base_labels
90113

91-
# Save annotated CSV
92-
output_name = f"{args.sample_id}_{read_id}_mapped.csv"
93-
raw_signal_df.to_csv(output_name, index=False)
114+
# Save annotated CSV
115+
output_name = f"{args.sample_id}_{read_id}_mapped.csv"
116+
raw_signal_df.to_csv(output_name, index=False)
117+
processed_reads += 1
118+
logging.info(f"Successfully processed read {read_id} -> {output_name}")
119+
except Exception as e:
120+
logging.error(f"Error processing read {read_id}: {e}")
121+
raise
122+
123+
logging.info(f"Successfully processed {processed_reads} reads from POD5 file")
124+
except Exception as e:
125+
logging.error(f"Failed to process POD5 file {args.pod5}: {e}")
126+
raise
127+
128+
logging.info(f"Completed processing for sample: {args.sample_id}")
129+
except Exception as e:
130+
logging.error(f"Fatal error in main processing: {e}")
131+
sys.exit(1)
94132

95133
if __name__ == "__main__":
96134
main()

0 commit comments

Comments
 (0)