Skip to content

Commit fe7b199

Browse files
committed
bugfix: handle the case where mv tags are not added to bam file
Closes #9
1 parent 8380ac6 commit fe7b199

1 file changed

Lines changed: 39 additions & 27 deletions

File tree

bin/plot_signal.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ def main():
3131
args = parser.parse_args()
3232

3333
# Load the data generated by pod5_to_df.py
34+
if not os.path.exists(args.csv):
35+
print(f"Error: CSV file {args.csv} not found.")
36+
return
37+
3438
raw_signal_df = pd.read_csv(args.csv)
3539

3640
# Plotting configuration
@@ -39,37 +43,44 @@ def main():
3943
fig, ax = plt.subplots(1, 1, figsize=(args.figwidth, args.figheight))
4044
x_feature, y_feature = 'time', 'signal'
4145

42-
# 1. Unannotated part (ann == -2)
43-
df_neg2 = raw_signal_df[raw_signal_df['ann'] == -2]
44-
if not df_neg2.empty:
45-
sns.scatterplot(data=df_neg2, x=x_feature, y=y_feature, color='blue',
46-
label='unannotated part', s=50, zorder=4, ax=ax)
46+
47+
# Check if 'ann' column exists (requires Dorado move tags)
48+
if 'ann' in raw_signal_df.columns:
49+
# 1. Unannotated part (ann == -2)
50+
df_neg2 = raw_signal_df[raw_signal_df['ann'] == -2]
51+
if not df_neg2.empty:
52+
sns.scatterplot(data=df_neg2, x=x_feature, y=y_feature, color='blue',
53+
label='unannotated part', s=50, zorder=4, ax=ax)
4754

48-
# 2. Trimmed primer/adapter (ann == -1)
49-
df_neg1 = raw_signal_df[raw_signal_df['ann'] == -1]
50-
if not df_neg1.empty:
51-
sns.lineplot(data=df_neg1, x=x_feature, y=y_feature, color='green',
52-
label='trimmed primer and adapter', zorder=2, ax=ax)
55+
# 2. Trimmed primer/adapter (ann == -1)
56+
df_neg1 = raw_signal_df[raw_signal_df['ann'] == -1]
57+
if not df_neg1.empty:
58+
sns.lineplot(data=df_neg1, x=x_feature, y=y_feature, color='green',
59+
label='trimmed primer and adapter', zorder=2, ax=ax)
5360

54-
# 3. Basecalled region (ann is 0 or 1)
55-
df_base = raw_signal_df[raw_signal_df['ann'].isin([0, 1])]
56-
if not df_base.empty:
57-
sns.lineplot(data=df_base, x=x_feature, y=y_feature, color='orange',
58-
label='basecalled region', zorder=3, ax=ax)
61+
# 3. Basecalled region (ann is 0 or 1)
62+
df_base = raw_signal_df[raw_signal_df['ann'].isin([0, 1])]
63+
if not df_base.empty:
64+
sns.lineplot(data=df_base, x=x_feature, y=y_feature, color='orange',
65+
label='basecalled region', zorder=3, ax=ax)
5966

60-
# 4. Poly-A Tail Region (polyA > -1)
61-
# We highlight the region identified by the 'pa' tag anchors
62-
df_polya = raw_signal_df[raw_signal_df['polyA'] > -1]
63-
if not df_polya.empty:
64-
# We use a distinct color (magenta) to show the estimated tail
65-
sns.lineplot(data=df_polya, x=x_feature, y=y_feature, color='magenta',
66-
label='polyA-tail region', zorder=5, linewidth=2, ax=ax)
67+
# 5. Samples that emit bases (ann == 1) - Red Circles
68+
df_emit = raw_signal_df[raw_signal_df['ann'] == 1]
69+
if not df_emit.empty:
70+
sns.scatterplot(data=df_emit, x=x_feature, y=y_feature, color='red',
71+
label='samples that emit bases', s=50, fc="none", ec='red', zorder=6, ax=ax)
72+
else:
73+
# Fallback: Plot raw signal if 'ann' is missing (Move tags were likely missing in BAM)
74+
print(f"Warning: 'ann' column missing in {args.csv}. Plotting raw signal only.")
75+
sns.lineplot(data=raw_signal_df, x=x_feature, y=y_feature, color='grey',
76+
alpha=0.5, label='raw signal (unannotated)', ax=ax)
6777

68-
# 5. Samples that emit bases (ann == 1) - Red Circles
69-
df_emit = raw_signal_df[raw_signal_df['ann'] == 1]
70-
if not df_emit.empty:
71-
sns.scatterplot(data=df_emit, x=x_feature, y=y_feature, color='red',
72-
label='samples that emit bases', s=50, fc="none", ec='red', zorder=6, ax=ax)
78+
# 4. Poly-A Tail Region (independent check for 'polyA' column)
79+
if 'polyA' in raw_signal_df.columns:
80+
df_polya = raw_signal_df[raw_signal_df['polyA'] > -1]
81+
if not df_polya.empty:
82+
sns.lineplot(data=df_polya, x=x_feature, y=y_feature, color='magenta',
83+
label='polyA-tail region', zorder=5, linewidth=2, ax=ax)
7384

7485
ax.set(title=args.title)
7586

@@ -78,6 +89,7 @@ def main():
7889
plt.tight_layout()
7990
plt.savefig(args.output, dpi=300)
8091
plt.close()
92+
print(f"Successfully saved plot to {args.output}")
8193

8294
if __name__ == "__main__":
8395
main()

0 commit comments

Comments
 (0)