Skip to content

Commit dd32f55

Browse files
authored
Fix plotting for complex homozygous alleles (#8)
* fixed plotting for homozyous complex allele * added transparent violinplots for complex alleles
1 parent 846ace3 commit dd32f55

File tree

3 files changed

+68
-31
lines changed

3 files changed

+68
-31
lines changed

src/genotyper/genotyping.py

+34-10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import List, Optional
2+
from typing import Any, Dict, List, Optional, Union
33

44
import numpy as np
55
import pandas as pd
@@ -14,23 +14,22 @@
1414
from .plotter import plot_clustering_preds, plot_complex_repeats
1515

1616

17-
def decode_alleles_complex(gmm_out_dict, df):
17+
def decode_alleles_complex(gmm_out_dict: Dict[str, Any], df: pd.DataFrame):
1818
if gmm_out_dict['is_hetero']:
1919
df1 = df.loc[gmm_out_dict['group1']]
2020
df2 = df.loc[gmm_out_dict['group2']]
21-
mediangroup1 = []
21+
mediangroup1: List[int] = []
2222
for col in df1.columns:
2323
if col != 'reverse':
24-
mediangroup1.append(find_nearest(df1[col], np.median(df1[col])))
24+
mediangroup1.append(int(find_nearest(df1[col].values, np.median(df1[col]))))
2525
mediangroup1_cnt = len(gmm_out_dict['group1'])
26-
mediangroup2 = []
26+
mediangroup2: List[int] = []
2727
for col in df2.columns:
2828
if col != 'reverse':
29-
mediangroup2.append(find_nearest(df2[col], np.median(df2[col])))
29+
mediangroup2.append(int(find_nearest(df2[col].values, np.median(df2[col]))))
3030
mediangroup2_cnt = len(gmm_out_dict['group2'])
3131
else:
3232
df1 = df.loc[gmm_out_dict['group1']]
33-
df2 = None
3433
mediangroup1 = []
3534
for col in df1.columns:
3635
if col != 'reverse':
@@ -119,7 +118,7 @@ def store_predictions(gt: Genotype, gt_bc: Optional[Genotype], locus_path: str):
119118
print(f'Allele lengths as given by basecall: {gt_bc.alleles}')
120119

121120

122-
def run_genotyping_complex(locus_path: str, df):
121+
def run_genotyping_complex(locus_path: str, df: Union[pd.DataFrame, None]):
123122
if df is None:
124123
inpath = os.path.join(locus_path, tmpl.PREDICTIONS_SUBDIR, tmpl.COMPLEX_SUBDIR, 'complex_repeat_units.csv')
125124
if os.path.isfile(inpath):
@@ -148,12 +147,37 @@ def run_genotyping_complex(locus_path: str, df):
148147
out['group2'] = [idx for idx, g in zip(df.index, preds) if g == 1]
149148
out['predictions'] = preds
150149

151-
alleles = decode_alleles_complex(out, df)
152150
if out['predictions'] is not None:
153151
df['allele'] = preds
154152

153+
alleles = decode_alleles_complex(out, df)
154+
if out['predictions'] is not None:
155+
homozygous = False
156+
print('Genotyped complex repeats in 2 alleles:')
157+
for idx, col in enumerate(cols):
158+
print(f'Unit: {col:10} Repeats: {alleles[0][idx]:5} {alleles[2][idx]:5}', )
159+
print(f'There were {alleles[1]} reads for allele1 and {alleles[3]} for allele2')
160+
161+
outpath = os.path.join(locus_path, tmpl.PREDICTIONS_SUBDIR, tmpl.COMPLEX_SUBDIR, 'complex_alleles.csv')
162+
with open(outpath, 'w') as f:
163+
f.write('unit,allele1_repeats,allele2_repeats\n')
164+
for idx, col in enumerate(cols):
165+
f.write(f'{col},{alleles[0][idx]},{alleles[2][idx]}\n')
166+
else:
167+
homozygous = True
168+
print('Genotyped complex repeats in a homozygous allele:')
169+
for idx, col in enumerate(cols):
170+
print(f'Unit: {col:10} Repeats: {alleles[0][idx]:5} {alleles[2]:5}', )
171+
print(f'There were {alleles[1]} reads for allele1 and {alleles[3]} for allele2')
172+
173+
outpath = os.path.join(locus_path, tmpl.PREDICTIONS_SUBDIR, tmpl.COMPLEX_SUBDIR, 'complex_alleles.csv')
174+
with open(outpath, 'w') as f:
175+
f.write('unit,allele1_repeats, allele2_repeats\n')
176+
for idx, col in enumerate(cols):
177+
f.write(f'{col},{alleles[0][idx]},{alleles[2]}\n')
178+
155179
img_path = os.path.join(locus_path, tmpl.SUMMARY_SUBDIR, 'complex_genotypes.svg')
156-
plot_complex_repeats(df, cols, alleles, img_path)
180+
plot_complex_repeats(df, cols, alleles, homozygous, img_path)
157181

158182

159183
def run_genotyping(unfilt_vals: List[int]):

src/genotyper/plotter.py

+33-20
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,45 @@
11
from typing import List, Optional, Tuple
22

33
import numpy as np
4+
import pandas as pd
45
import seaborn as sns
56
from matplotlib import pyplot as plt
6-
from matplotlib.collections import PolyCollection
77

88

9-
def plot_complex_repeats(df, cols, alleles, img_path: str):
10-
fig, axes = plt.subplots(nrows=len(cols), ncols=1, figsize=(8, 6*len(cols)))
9+
def plot_complex_repeats(
10+
df: pd.DataFrame,
11+
cols: List[str],
12+
alleles: Tuple[List[int], int, List[int], int],
13+
homozygous: bool,
14+
img_path: str
15+
):
16+
_, axes = plt.subplots(nrows=len(cols), ncols=1, figsize=(8, 6*len(cols)))
1117
for idx, col in enumerate(cols):
1218
ex_df = df
13-
val1 = alleles[0][idx]
14-
val2 = alleles[2][idx]
15-
name = 'repeat numbers: '+str(val1)+','+str(val2)
16-
ex_df[name] = ''
17-
axes[idx] = sns.violinplot(data=ex_df, x=name, y=col, hue='allele', orient='vertical',
18-
split=False, scale='count', whis=np.inf, inner=None, ax=axes[idx])
19-
plt.setp(axes[idx].collections, alpha=.3)
20-
first = [r for r in axes[idx].get_children() if type(r) == PolyCollection]
21-
c1 = first[0].get_facecolor()[0]
22-
c2 = first[1].get_facecolor()[0]
23-
if val1 != '-':
24-
axes[idx].axhline(y=val1, color=c1, linestyle='--')
25-
if val2 != '-':
26-
axes[idx].axhline(y=val2, color=c2, linestyle='--')
27-
axes[idx] = sns.stripplot(data=ex_df, x=name, y=col, hue='allele', orient='vertical',
28-
dodge=True, size=6, alpha=0.8, jitter=0.3, ax=axes[idx])
29-
axes[idx].get_legend().remove()
19+
20+
if not homozygous:
21+
val1 = alleles[0][idx]
22+
val2 = alleles[2][idx]
23+
name = 'repeat numbers: '+str(val1)+','+str(val2)
24+
ex_df[name] = ''
25+
axes[idx] = sns.violinplot(data=ex_df, x=name, y=col, hue='allele', orient='vertical',
26+
split=False, scale='count', whis=np.inf, inner=None, ax=axes[idx])
27+
plt.setp(axes[idx].collections, alpha=.3)
28+
axes[idx].axhline(y=val1, color='b', linestyle='--')
29+
axes[idx].axhline(y=val2, color='b', linestyle='--')
30+
axes[idx] = sns.stripplot(data=ex_df, x=name, y=col, hue='allele', orient='vertical',
31+
dodge=True, size=6, alpha=0.8, jitter=0.3, ax=axes[idx])
32+
axes[idx].get_legend().remove()
33+
else:
34+
val1 = alleles[0][idx]
35+
name = 'repeat numbers: '+str(val1)+',-'
36+
ex_df[name] = ''
37+
axes[idx] = sns.violinplot(data=ex_df, x=name, y=col, orient='vertical',
38+
split=False, scale='count', whis=np.inf, inner=None, ax=axes[idx])
39+
plt.setp(axes[idx].collections, alpha=.3)
40+
axes[idx].axhline(y=val1, color='b', linestyle='--')
41+
axes[idx] = sns.stripplot(data=ex_df, x=name, y=col, orient='vertical',
42+
dodge=True, size=6, alpha=0.8, jitter=0.3, ax=axes[idx])
3043

3144
plt.savefig(img_path, bbox_inches='tight', format='svg')
3245
plt.close()

src/schemas/genotype.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,5 @@ def alleles(self):
3535
return (self.first_allele, self.second_allele)
3636

3737

38-
def find_nearest(array: List[int], value: float):
38+
def find_nearest(array: List[int], value: float) -> float:
3939
return array[(np.abs(np.asarray(array) - value)).argmin()]

0 commit comments

Comments
 (0)