1111import pandas as pd
1212from collections import defaultdict
1313
14+ from .__init__ import SEARCH_DEPTH_FACTOR , DEFAULT_LINEAGE_RESOLUTION
15+
1416from .assign import assign_query_hdf5
1517from .network import construct_network_from_edge_list , printClusters , save_network
1618from .models import LineageFit
1719from .plot import writeClusterCsv
1820from .sketchlib import readDBParams
1921from .qc import prune_distance_matrix , sketchlibAssemblyQC
20- from .utils import createOverallLineage , get_match_search_depth , readPickle , setupDBFuncs
22+ from .utils import createOverallLineage , readPickle , setupDBFuncs , update_distance_matrices , storePickle
23+
24+ import pp_sketchlib
2125
2226# command line parsing
2327def get_options ():
@@ -114,7 +118,7 @@ def get_options():
114118 help = "Number of kNN distances per sequence to filter when "
115119 "counting neighbours or using only reciprocal matches" ,
116120 type = int ,
117- default = None )
121+ default = 10000 )
118122 lGroup .add_argument ('--use-accessory' ,
119123 help = "Use accessory distances for lineage clustering" ,
120124 action = 'store_true' ,
@@ -130,6 +134,10 @@ def get_options():
130134 help = "Only use reciprocal kNN matches for lineage definitions" ,
131135 action = 'store_true' ,
132136 default = False )
137+ lGroup .add_argument ('--lineage-resolution' ,
138+ help = "Minimum genetic separation between isolates required to initiate a new lineage" ,
139+ type = float ,
140+ default = DEFAULT_LINEAGE_RESOLUTION )
133141
134142 return parser .parse_args ()
135143
@@ -165,15 +173,13 @@ def create_db(args):
165173 else :
166174 clustering_file = args .external_clustering
167175 strains = pd .read_csv (clustering_file , dtype = str ).groupby (args .clustering_col_name )
168-
176+
169177 sys .stderr .write ("Extracting properties of database\n " )
170178 # Get rlist
171179 if args .distances is None :
172180 distances = os .path .join (args .create_db ,os .path .basename (args .create_db ) + ".dists" )
173181 else :
174182 distances = args .distances
175- # Get distances
176- rlist , qlist , self , X = readPickle (distances , enforce_self = False , distances = True )
177183 # Get parameters
178184 kmers , sketch_sizes , codon_phased = readDBParams (args .create_db )
179185 # Ranks to use
@@ -185,9 +191,15 @@ def create_db(args):
185191 else :
186192 max_search_depth = args .max_search_depth
187193 else :
188- max_search_depth = get_match_search_depth (rlist ,rank_list )
194+ # By default retain more search distances than the maximum
195+ # requested rank: when only unique distances are counted, and
196+ # distances differing by less than epsilon are merged, more
197+ # values than the maximum rank are required, so the depth is
198+ # scaled by SEARCH_DEPTH_FACTOR
199+ max_search_depth = max (rank_list )* SEARCH_DEPTH_FACTOR
189200
190201 sys .stderr .write ("Generating databases for individual strains\n " )
202+ all_isolates = list ()
191203 # Dicts for storing typing information
192204 lineage_dbs = {}
193205 overall_lineage = {}
@@ -199,6 +211,7 @@ def create_db(args):
199211 num_isolates = len (isolate_list )
200212 if num_isolates >= args .min_count :
201213 lineage_dbs [strain ] = strain_db_name
214+ all_isolates .extend (isolate_list )
202215 if os .path .isdir (strain_db_name ) and args .overwrite :
203216 sys .stderr .write ("--overwrite means {strain_db_name} will be deleted now\n " )
204217 shutil .rmtree (strain_db_name )
@@ -217,27 +230,32 @@ def create_db(args):
217230 shutil .rmtree (dest_db )
218231 elif not os .path .exists (dest_db ):
219232 os .symlink (rel_path ,dest_db )
220- # Extract sparse distances
221- prune_distance_matrix (rlist ,
222- list (set (rlist ) - set (isolate_list )),
223- X ,
224- os .path .join (strain_db_name ,strain_db_name + '.dists' ))
233+ # Store isolate names
234+ storePickle (isolate_list , isolate_list , True , None , os .path .join (strain_db_name ,strain_db_name + '.dists' ))
235+ # Calculate within-strain distances
236+ strain_distMat = pp_sketchlib .queryDatabase (ref_db_name = dest_db .replace ('.h5' ,'' ),
237+ query_db_name = dest_db .replace ('.h5' ,'' ),
238+ rList = isolate_list ,
239+ qList = isolate_list ,
240+ klist = kmers .tolist (),
241+ random_correct = True ,
242+ jaccard = False ,
243+ num_threads = args .threads ,
244+ use_gpu = args .gpu_dist ,
245+ device_id = args .deviceid )
246+
225247 # Initialise model
226248 model = LineageFit (strain_db_name ,
227249 rank_list ,
228250 max_search_depth ,
229251 args .reciprocal_only ,
230252 args .count_unique_distances ,
253+ args .lineage_resolution ,
254+ dist_col = 1 if args .use_accessory else 0 ,
231255 use_gpu = args .gpu_graph )
232256 model .set_threads (args .threads )
233- # Load pruned distance matrix
234- strain_rlist , strain_qlist , strain_self , strain_X = \
235- readPickle (os .path .join (strain_db_name ,strain_db_name + '.dists' ),
236- enforce_self = False ,
237- distances = True )
238257 # Fit model
239- model .fit (strain_X ,
240- args .use_accessory )
258+ model .fit (strain_distMat )
241259 # Lineage fit requires some iteration
242260 indivNetworks = {}
243261 lineage_clusters = defaultdict (dict )
@@ -246,8 +264,8 @@ def create_db(args):
246264 if rank <= num_isolates :
247265 assignments = model .assign (rank )
248266 # Generate networks
249- indivNetworks [rank ] = construct_network_from_edge_list (strain_rlist ,
250- strain_rlist ,
267+ indivNetworks [rank ] = construct_network_from_edge_list (isolate_list ,
268+ isolate_list ,
251269 assignments ,
252270 weights = None ,
253271 betweenness_sample = None ,
@@ -262,7 +280,7 @@ def create_db(args):
262280 # Identify clusters from output
263281 lineage_clusters [rank ] = \
264282 printClusters (indivNetworks [rank ],
265- strain_rlist ,
283+ isolate_list ,
266284 printCSV = False ,
267285 use_gpu = args .gpu_graph )[0 ]
268286 n_clusters = max (lineage_clusters [rank ].values ())
@@ -271,8 +289,8 @@ def create_db(args):
271289 # For each strain, print output of each rank as CSV
272290 overall_lineage [strain ] = createOverallLineage (rank_list , lineage_clusters )
273291 writeClusterCsv (os .path .join (strain_db_name ,os .path .basename (strain_db_name ) + '_lineages.csv' ),
274- strain_rlist ,
275- strain_rlist ,
292+ isolate_list ,
293+ isolate_list ,
276294 overall_lineage [strain ],
277295 output_format = 'phandango' ,
278296 epiCsv = None ,
@@ -282,12 +300,12 @@ def create_db(args):
282300 model .save ()
283301
284302 # Print combined strain and lineage clustering
285- print_overall_clustering (overall_lineage ,args .output + '.csv' ,rlist )
303+ print_overall_clustering (overall_lineage ,args .output + '.csv' ,all_isolates )
286304
287305 # Write scheme to file
288306 with open (args .db_scheme , 'wb' ) as pickle_file :
289307 pickle .dump ([args .create_db ,
290- rlist ,
308+ isolate_list ,
291309 args .model_dir ,
292310 clustering_file ,
293311 args .clustering_col_name ,
0 commit comments