|
1 | 1 | import os
|
2 |
| -from typing import List, Optional |
| 2 | +from typing import Any, Dict, List, Optional, Union |
3 | 3 |
|
4 | 4 | import numpy as np
|
5 | 5 | import pandas as pd
|
|
14 | 14 | from .plotter import plot_clustering_preds, plot_complex_repeats
|
15 | 15 |
|
16 | 16 |
|
17 |
| -def decode_alleles_complex(gmm_out_dict, df): |
| 17 | +def decode_alleles_complex(gmm_out_dict: Dict[str, Any], df: pd.DataFrame): |
18 | 18 | if gmm_out_dict['is_hetero']:
|
19 | 19 | df1 = df.loc[gmm_out_dict['group1']]
|
20 | 20 | df2 = df.loc[gmm_out_dict['group2']]
|
21 |
| - mediangroup1 = [] |
| 21 | + mediangroup1: List[int] = [] |
22 | 22 | for col in df1.columns:
|
23 | 23 | if col != 'reverse':
|
24 |
| - mediangroup1.append(find_nearest(df1[col], np.median(df1[col]))) |
| 24 | + mediangroup1.append(int(find_nearest(df1[col].values, np.median(df1[col])))) |
25 | 25 | mediangroup1_cnt = len(gmm_out_dict['group1'])
|
26 |
| - mediangroup2 = [] |
| 26 | + mediangroup2: List[int] = [] |
27 | 27 | for col in df2.columns:
|
28 | 28 | if col != 'reverse':
|
29 |
| - mediangroup2.append(find_nearest(df2[col], np.median(df2[col]))) |
| 29 | + mediangroup2.append(int(find_nearest(df2[col].values, np.median(df2[col])))) |
30 | 30 | mediangroup2_cnt = len(gmm_out_dict['group2'])
|
31 | 31 | else:
|
32 | 32 | df1 = df.loc[gmm_out_dict['group1']]
|
33 |
| - df2 = None |
34 | 33 | mediangroup1 = []
|
35 | 34 | for col in df1.columns:
|
36 | 35 | if col != 'reverse':
|
@@ -119,7 +118,7 @@ def store_predictions(gt: Genotype, gt_bc: Optional[Genotype], locus_path: str):
|
119 | 118 | print(f'Allele lengths as given by basecall: {gt_bc.alleles}')
|
120 | 119 |
|
121 | 120 |
|
122 |
| -def run_genotyping_complex(locus_path: str, df): |
| 121 | +def run_genotyping_complex(locus_path: str, df: Union[pd.DataFrame, None]): |
123 | 122 | if df is None:
|
124 | 123 | inpath = os.path.join(locus_path, tmpl.PREDICTIONS_SUBDIR, tmpl.COMPLEX_SUBDIR, 'complex_repeat_units.csv')
|
125 | 124 | if os.path.isfile(inpath):
|
@@ -148,12 +147,37 @@ def run_genotyping_complex(locus_path: str, df):
|
148 | 147 | out['group2'] = [idx for idx, g in zip(df.index, preds) if g == 1]
|
149 | 148 | out['predictions'] = preds
|
150 | 149 |
|
151 |
| - alleles = decode_alleles_complex(out, df) |
152 | 150 | if out['predictions'] is not None:
|
153 | 151 | df['allele'] = preds
|
154 | 152 |
|
| 153 | + alleles = decode_alleles_complex(out, df) |
| 154 | + if out['predictions'] is not None: |
| 155 | + homozygous = False |
| 156 | + print('Genotyped complex repeats in 2 alleles:') |
| 157 | + for idx, col in enumerate(cols): |
| 158 | + print(f'Unit: {col:10} Repeats: {alleles[0][idx]:5} {alleles[2][idx]:5}', ) |
| 159 | + print(f'There were {alleles[1]} reads for allele1 and {alleles[3]} for allele2') |
| 160 | + |
| 161 | + outpath = os.path.join(locus_path, tmpl.PREDICTIONS_SUBDIR, tmpl.COMPLEX_SUBDIR, 'complex_alleles.csv') |
| 162 | + with open(outpath, 'w') as f: |
| 163 | + f.write('unit,allele1_repeats,allele2_repeats\n') |
| 164 | + for idx, col in enumerate(cols): |
| 165 | + f.write(f'{col},{alleles[0][idx]},{alleles[2][idx]}\n') |
| 166 | + else: |
| 167 | + homozygous = True |
| 168 | + print('Genotyped complex repeats in a homozygous allele:') |
| 169 | + for idx, col in enumerate(cols): |
| 170 | + print(f'Unit: {col:10} Repeats: {alleles[0][idx]:5} {alleles[2]:5}', ) |
| 171 | + print(f'There were {alleles[1]} reads for allele1 and {alleles[3]} for allele2') |
| 172 | + |
| 173 | + outpath = os.path.join(locus_path, tmpl.PREDICTIONS_SUBDIR, tmpl.COMPLEX_SUBDIR, 'complex_alleles.csv') |
| 174 | + with open(outpath, 'w') as f: |
| 175 | + f.write('unit,allele1_repeats, allele2_repeats\n') |
| 176 | + for idx, col in enumerate(cols): |
| 177 | + f.write(f'{col},{alleles[0][idx]},{alleles[2]}\n') |
| 178 | + |
155 | 179 | img_path = os.path.join(locus_path, tmpl.SUMMARY_SUBDIR, 'complex_genotypes.svg')
|
156 |
| - plot_complex_repeats(df, cols, alleles, img_path) |
| 180 | + plot_complex_repeats(df, cols, alleles, homozygous, img_path) |
157 | 181 |
|
158 | 182 |
|
159 | 183 | def run_genotyping(unfilt_vals: List[int]):
|
|
0 commit comments