Skip to content

Commit fe8c548

Browse files
authored
Merge pull request #365 from bacpop/lineage_fixes
Fixes to lineage fitting and visualisation
2 parents 1338e46 + acb4e9a commit fe8c548

16 files changed

Lines changed: 207 additions & 142 deletions

.github/workflows/azure_ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ jobs:
3030
micromamba-version: '1.4.6-0'
3131
environment-file: environment.yml
3232
# persist on the same day.
33-
cache-environment-key: environment-${{ steps.date.outputs.date }}
34-
cache-downloads-key: downloads-${{ steps.date.outputs.date }}
33+
# cache-environment-key: environment-${{ steps.date.outputs.date }}
34+
# cache-downloads-key: downloads-${{ steps.date.outputs.date }}
3535
- name: Install and run_test.py
3636
shell: bash -l {0}
3737
run: |

PopPUNK/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@
33

44
'''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)'''
55

6-
__version__ = '2.7.7'
6+
__version__ = '2.7.8'
77

88
# Minimum sketchlib version
99
SKETCHLIB_MAJOR = 2
1010
SKETCHLIB_MINOR = 0
1111
SKETCHLIB_PATCH = 1
12+
13+
# Lineage search depth default
14+
SEARCH_DEPTH_FACTOR = 10
15+
DEFAULT_LINEAGE_RESOLUTION = 1e-10

PopPUNK/__main__.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
# import poppunk package
1313
from .__init__ import __version__
14+
from .__init__ import SEARCH_DEPTH_FACTOR, DEFAULT_LINEAGE_RESOLUTION
1415

1516
# globals
1617
accepted_weights_types = ["core", "accessory", "euclidean"]
@@ -190,7 +191,7 @@ def get_options():
190191
help='Number of kNN distances per sequence to filter when '
191192
'counting neighbours or using only reciprocal matches',
192193
type = int,
193-
default = None)
194+
default = 10000)
194195
lineagesGroup.add_argument('--write-lineage-networks',
195196
help='Save all lineage networks',
196197
action = 'store_true',
@@ -199,6 +200,10 @@ def get_options():
199200
help='Use accessory distances for lineage definitions [default = use core distances]',
200201
action = 'store_true',
201202
default = False)
203+
lineagesGroup.add_argument('--lineage-resolution',
204+
help='Minimum genetic separation between isolates required to initiate a new lineage',
205+
type = float,
206+
default = DEFAULT_LINEAGE_RESOLUTION)
202207

203208
other = parser.add_argument_group('Other options')
204209
other.add_argument('--threads', default=1, type=int, help='Number of threads to use [default = 1]')
@@ -273,7 +278,6 @@ def main():
273278
from .utils import setupDBFuncs
274279
from .utils import readPickle, storePickle
275280
from .utils import createOverallLineage
276-
from .utils import get_match_search_depth
277281
from .utils import check_and_set_gpu
278282

279283
# check kmer properties
@@ -568,21 +572,24 @@ def main():
568572
# Memory usage determined by maximum search depth
569573
if args.max_search_depth is not None:
570574
max_search_depth = int(args.max_search_depth)
571-
elif args.max_search_depth is None and (args.reciprocal_only or args.count_unique_distances):
572-
max_search_depth = get_match_search_depth(refList,rank_list)
573575
else:
574-
max_search_depth = max(rank_list)
576+
# By default retain a larger number of search distances
577+
# than the maximum requested rank because when counting only
578+
# unique distances, and merging distances differing by less
579+
# than epsilon, more than the max rank number of values is
580+
# required
581+
max_search_depth = max(rank_list)*SEARCH_DEPTH_FACTOR
575582

576583
model = LineageFit(output,
577584
rank_list,
578585
max_search_depth,
579586
args.reciprocal_only,
580587
args.count_unique_distances,
588+
args.lineage_resolution,
581589
1 if args.use_accessory else 0,
582590
use_gpu = args.gpu_graph)
583591
model.set_threads(args.threads)
584-
model.fit(distMat,
585-
args.use_accessory)
592+
model.fit(distMat)
586593

587594
assignments = {}
588595
for rank in rank_list:

PopPUNK/assign.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,11 @@ def get_options():
128128
# combine
129129
args = parser.parse_args()
130130

131-
# ensure directories do not have trailing forward slash
132-
for arg in [args.db, args.model_dir, args.output, args.previous_clustering]:
133-
if arg is not None:
134-
arg = arg.rstrip('\\')
131+
# ensure directories do not have trailing slash
132+
for attr_name in ['db', 'model_dir', 'output', 'previous_clustering']:
133+
attr_value = getattr(args, attr_name)
134+
if attr_value is not None:
135+
setattr(args, attr_name, attr_value.rstrip('\\').rstrip('/'))
135136

136137
return args
137138

@@ -275,7 +276,6 @@ def assign_query(dbFuncs,
275276
createDatabaseDir = dbFuncs['createDatabaseDir']
276277
constructDatabase = dbFuncs['constructDatabase']
277278
readDBParams = dbFuncs['readDBParams']
278-
279279
if ref_db == output and overwrite == False:
280280
sys.stderr.write("--output and --db must be different to "
281281
"prevent overwrite.\n")

PopPUNK/lineages.py

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@
1111
import pandas as pd
1212
from collections import defaultdict
1313

14+
from .__init__ import SEARCH_DEPTH_FACTOR, DEFAULT_LINEAGE_RESOLUTION
15+
1416
from .assign import assign_query_hdf5
1517
from .network import construct_network_from_edge_list, printClusters, save_network
1618
from .models import LineageFit
1719
from .plot import writeClusterCsv
1820
from .sketchlib import readDBParams
1921
from .qc import prune_distance_matrix, sketchlibAssemblyQC
20-
from .utils import createOverallLineage, get_match_search_depth, readPickle, setupDBFuncs
22+
from .utils import createOverallLineage, readPickle, setupDBFuncs, update_distance_matrices, storePickle
23+
24+
import pp_sketchlib
2125

2226
# command line parsing
2327
def get_options():
@@ -114,7 +118,7 @@ def get_options():
114118
help="Number of kNN distances per sequence to filter when "
115119
"counting neighbours or using only reciprocal matches",
116120
type = int,
117-
default = None)
121+
default = 10000)
118122
lGroup.add_argument('--use-accessory',
119123
help="Use accessory distances for lineage clustering",
120124
action = 'store_true',
@@ -130,6 +134,10 @@ def get_options():
130134
help="Only use reciprocal kNN matches for lineage definitions",
131135
action = 'store_true',
132136
default = False)
137+
lGroup.add_argument('--lineage-resolution',
138+
help="Minimum genetic separation between isolates required to initiate a new lineage",
139+
type = float,
140+
default = DEFAULT_LINEAGE_RESOLUTION)
133141

134142
return parser.parse_args()
135143

@@ -165,15 +173,13 @@ def create_db(args):
165173
else:
166174
clustering_file = args.external_clustering
167175
strains = pd.read_csv(clustering_file, dtype = str).groupby(args.clustering_col_name)
168-
176+
169177
sys.stderr.write("Extracting properties of database\n")
170178
# Get rlist
171179
if args.distances is None:
172180
distances = os.path.join(args.create_db,os.path.basename(args.create_db) + ".dists")
173181
else:
174182
distances = args.distances
175-
# Get distances
176-
rlist, qlist, self, X = readPickle(distances, enforce_self=False, distances=True)
177183
# Get parameters
178184
kmers, sketch_sizes, codon_phased = readDBParams(args.create_db)
179185
# Ranks to use
@@ -185,9 +191,15 @@ def create_db(args):
185191
else:
186192
max_search_depth = args.max_search_depth
187193
else:
188-
max_search_depth = get_match_search_depth(rlist,rank_list)
194+
# By default retain a larger number of search distances
195+
# than the maximum requested rank because when counting only
196+
# unique distances, and merging distances differing by less
197+
# than epsilon, more than the max rank number of values is
198+
# required
199+
max_search_depth = max(rank_list)*SEARCH_DEPTH_FACTOR
189200

190201
sys.stderr.write("Generating databases for individual strains\n")
202+
all_isolates = list()
191203
# Dicts for storing typing information
192204
lineage_dbs = {}
193205
overall_lineage = {}
@@ -199,6 +211,7 @@ def create_db(args):
199211
num_isolates = len(isolate_list)
200212
if num_isolates >= args.min_count:
201213
lineage_dbs[strain] = strain_db_name
214+
all_isolates.extend(isolate_list)
202215
if os.path.isdir(strain_db_name) and args.overwrite:
203216
sys.stderr.write("--overwrite means {strain_db_name} will be deleted now\n")
204217
shutil.rmtree(strain_db_name)
@@ -217,27 +230,32 @@ def create_db(args):
217230
shutil.rmtree(dest_db)
218231
elif not os.path.exists(dest_db):
219232
os.symlink(rel_path,dest_db)
220-
# Extract sparse distances
221-
prune_distance_matrix(rlist,
222-
list(set(rlist) - set(isolate_list)),
223-
X,
224-
os.path.join(strain_db_name,strain_db_name + '.dists'))
233+
# Store isolate names
234+
storePickle(isolate_list, isolate_list, True, None, os.path.join(strain_db_name,strain_db_name + '.dists'))
235+
# Calculate within-strain distances
236+
strain_distMat = pp_sketchlib.queryDatabase(ref_db_name=dest_db.replace('.h5',''),
237+
query_db_name=dest_db.replace('.h5',''),
238+
rList=isolate_list,
239+
qList=isolate_list,
240+
klist=kmers.tolist(),
241+
random_correct=True,
242+
jaccard=False,
243+
num_threads=args.threads,
244+
use_gpu = args.gpu_dist,
245+
device_id = args.deviceid)
246+
225247
# Initialise model
226248
model = LineageFit(strain_db_name,
227249
rank_list,
228250
max_search_depth,
229251
args.reciprocal_only,
230252
args.count_unique_distances,
253+
args.lineage_resolution,
254+
dist_col = 1 if args.use_accessory else 0,
231255
use_gpu = args.gpu_graph)
232256
model.set_threads(args.threads)
233-
# Load pruned distance matrix
234-
strain_rlist, strain_qlist, strain_self, strain_X = \
235-
readPickle(os.path.join(strain_db_name,strain_db_name + '.dists'),
236-
enforce_self=False,
237-
distances=True)
238257
# Fit model
239-
model.fit(strain_X,
240-
args.use_accessory)
258+
model.fit(strain_distMat)
241259
# Lineage fit requires some iteration
242260
indivNetworks = {}
243261
lineage_clusters = defaultdict(dict)
@@ -246,8 +264,8 @@ def create_db(args):
246264
if rank <= num_isolates:
247265
assignments = model.assign(rank)
248266
# Generate networks
249-
indivNetworks[rank] = construct_network_from_edge_list(strain_rlist,
250-
strain_rlist,
267+
indivNetworks[rank] = construct_network_from_edge_list(isolate_list,
268+
isolate_list,
251269
assignments,
252270
weights = None,
253271
betweenness_sample = None,
@@ -262,7 +280,7 @@ def create_db(args):
262280
# Identify clusters from output
263281
lineage_clusters[rank] = \
264282
printClusters(indivNetworks[rank],
265-
strain_rlist,
283+
isolate_list,
266284
printCSV = False,
267285
use_gpu = args.gpu_graph)[0]
268286
n_clusters = max(lineage_clusters[rank].values())
@@ -271,8 +289,8 @@ def create_db(args):
271289
# For each strain, print output of each rank as CSV
272290
overall_lineage[strain] = createOverallLineage(rank_list, lineage_clusters)
273291
writeClusterCsv(os.path.join(strain_db_name,os.path.basename(strain_db_name) + '_lineages.csv'),
274-
strain_rlist,
275-
strain_rlist,
292+
isolate_list,
293+
isolate_list,
276294
overall_lineage[strain],
277295
output_format = 'phandango',
278296
epiCsv = None,
@@ -282,12 +300,12 @@ def create_db(args):
282300
model.save()
283301

284302
# Print combined strain and lineage clustering
285-
print_overall_clustering(overall_lineage,args.output + '.csv',rlist)
303+
print_overall_clustering(overall_lineage,args.output + '.csv',all_isolates)
286304

287305
# Write scheme to file
288306
with open(args.db_scheme, 'wb') as pickle_file:
289307
pickle.dump([args.create_db,
290-
rlist,
308+
isolate_list,
291309
args.model_dir,
292310
clustering_file,
293311
args.clustering_col_name,

PopPUNK/models.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,11 +1120,13 @@ class LineageFit(ClusterFit):
11201120
The ranks used in the fit
11211121
'''
11221122

1123-
def __init__(self, outPrefix, ranks, max_search_depth, reciprocal_only, count_unique_distances, dist_col = None, use_gpu = False):
1123+
def __init__(self, outPrefix, ranks, max_search_depth, reciprocal_only,
1124+
count_unique_distances, lineage_resolution, dist_col = None, use_gpu = False):
11241125
ClusterFit.__init__(self, outPrefix)
11251126
self.type = 'lineage'
11261127
self.preprocess = False
1127-
self.max_search_depth = max_search_depth+5 # Set to highest rank by default in main; need to store additional distances
1128+
max_rank = max(ranks)
1129+
self.max_search_depth = max(max_search_depth,max_rank+5) # Set to highest rank by default in main; need to store additional distances
11281130
# when there is redundancy (e.g. reciprocal matching, unique distance counting)
11291131
# or other sequences may be pruned out of the database
11301132
self.nn_dists = None # stores the unprocessed kNN at the maximum search depth
@@ -1139,6 +1141,7 @@ def __init__(self, outPrefix, ranks, max_search_depth, reciprocal_only, count_un
11391141
self.reciprocal_only = reciprocal_only
11401142
self.count_unique_distances = count_unique_distances
11411143
self.dist_col = dist_col
1144+
self.resolution = lineage_resolution
11421145
self.use_gpu = use_gpu
11431146

11441147
def __save_sparse__(self, data, row, col, rank, n_samples, dtype, is_nn_dist = False):
@@ -1177,6 +1180,7 @@ def __reduce_rank__(self, higher_rank_sparse_mat, lower_rank, n_samples, dtype):
11771180
lower_rank,
11781181
self.reciprocal_only,
11791182
self.count_unique_distances,
1183+
self.resolution,
11801184
self.threads)
11811185
self.__save_sparse__(lower_rank_sparse_mat[2],
11821186
lower_rank_sparse_mat[0],
@@ -1185,7 +1189,7 @@ def __reduce_rank__(self, higher_rank_sparse_mat, lower_rank, n_samples, dtype):
11851189
n_samples,
11861190
dtype)
11871191

1188-
def fit(self, X, accessory):
1192+
def fit(self, X):
11891193
'''Extends :func:`~ClusterFit.fit`
11901194
11911195
Gets assignments by using nearest neighbours.
@@ -1194,8 +1198,6 @@ def fit(self, X, accessory):
11941198
X (numpy.array)
11951199
The core and accessory distances to cluster. Must be set if
11961200
preprocess is set.
1197-
accessory (bool)
1198-
Use accessory rather than core distances
11991201
12001202
Returns:
12011203
y (numpy.array)
@@ -1205,23 +1207,20 @@ def fit(self, X, accessory):
12051207
ClusterFit.fit(self, X)
12061208
sample_size = int(round(0.5 * (1 + np.sqrt(1 + 8 * X.shape[0]))))
12071209
if (max(self.ranks) >= sample_size):
1208-
sys.stderr.write("Rank must be less than the number of samples")
1210+
sys.stderr.write("Maximum rank must be less than the number of samples: " + str(sample_size) + "\n")
12091211
sys.exit(0)
12101212

1211-
if accessory:
1212-
self.dist_col = 1
1213-
else:
1214-
self.dist_col = 0
1213+
search_depth = min(self.max_search_depth,sample_size-1)
12151214

12161215
row, col, data = \
12171216
poppunk_refine.get_kNN_distances(
12181217
distMat=pp_sketchlib.longToSquare(distVec=X[:, [self.dist_col]],
12191218
num_threads=self.threads),
1220-
kNN=self.max_search_depth,
1219+
kNN=search_depth,
12211220
dist_col=self.dist_col,
12221221
num_threads=self.threads
12231222
)
1224-
self.__save_sparse__(data, row, col, self.max_search_depth, sample_size, X.dtype,
1223+
self.__save_sparse__(data, row, col, search_depth, sample_size, X.dtype,
12251224
is_nn_dist = True)
12261225

12271226
# Apply filtering of links if requested and extract lower ranks - parallelisation within C++ code
@@ -1258,7 +1257,8 @@ def save(self):
12581257
self.max_search_depth,
12591258
self.reciprocal_only,
12601259
self.count_unique_distances,
1261-
self.dist_col],
1260+
self.dist_col,
1261+
self.resolution],
12621262
self.type],
12631263
pickle_file)
12641264

@@ -1271,7 +1271,7 @@ def load(self, fit_npz, fit_obj):
12711271
fit_obj (sklearn.mixture.BayesianGaussianMixture)
12721272
The saved fit object
12731273
'''
1274-
self.ranks, self.max_search_depth, self.reciprocal_only, self.count_unique_distances, self.dist_col = fit_obj
1274+
self.ranks, self.max_search_depth, self.reciprocal_only, self.count_unique_distances, self.dist_col, self.resolution = fit_obj
12751275
self.nn_dists = fit_npz
12761276
self.fitted = True
12771277

0 commit comments

Comments
 (0)