-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathpartitiondriver.py
More file actions
2552 lines (2272 loc) · 202 KB
/
partitiondriver.py
File metadata and controls
2552 lines (2272 loc) · 202 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import absolute_import, division, unicode_literals
from __future__ import print_function
import json
import numpy
import time
import sys
import itertools
import math
import os
import glob
import csv
from io import open
csv.field_size_limit(sys.maxsize) # make sure we can write very large csv fields
import random
from collections import OrderedDict
from subprocess import check_call, CalledProcessError
import copy
import multiprocessing
import operator
import traceback
from . import utils
from . import glutils
from . import indelutils
from . import treeutils
from . import lbplotting
from .glomerator import Glomerator
from .clusterpath import ClusterPath, ptnprint
from .waterer import Waterer
from .parametercounter import ParameterCounter
from .alleleclusterer import AlleleClusterer
from .alleleremover import AlleleRemover
from .allelefinder import AlleleFinder
from .performanceplotter import PerformancePlotter
from .partitionplotter import PartitionPlotter
from .hist import Hist
from . import seqfileopener
# ----------------------------------------------------------------------------------------
class PartitionDriver(object):
""" Class to parse input files, start bcrham jobs, and parse/interpret bcrham output for annotation and partitioning """
def __init__(self, args, glfo, input_info, simglfo, reco_info):
    """Set up work dirs, parameter/output file paths, optional input-partition info, and the action dispatch table.

    args: parsed command line arguments
    glfo: germline set info used for inference
    input_info: ordered dict of input queries (uid -> sequence info)
    simglfo: true (simulation) germline set info, presumably None if not running on simulation -- TODO confirm
    reco_info: true rearrangement info, presumably None if not running on simulation -- TODO confirm
    """
    self.args = args
    self.glfo = glfo
    self.input_info = input_info
    self.simglfo = simglfo
    self.reco_info = reco_info
    utils.prep_dir(self.args.workdir)
    self.my_gldir = self.args.workdir + '/' + glutils.glfo_dir
    # annotation info from the various pre-hmm steps: vsearch, smith-waterman, and msa-based indel finding (all filled in later)
    self.vs_info, self.sw_info, self.msa_vs_info = None, None, None
    self.duplicates = {}
    self.bcrham_proc_info = None
    self.timing_info = []  # it would be really nice to clean up both this and bcrham_proc_info
    self.istep = None  # stupid hack to get around network file system issues (see self.subworkdir())
    self.subworkdirs = []  # arg. same stupid hack
    self.unseeded_seqs = None  # all the queries that we *didn't* cluster with the seed uid
    self.small_cluster_seqs = None  # all the queries that we removed after a few partition steps 'cause they were in small clusters
    self.sw_param_dir, self.hmm_param_dir, self.multi_hmm_param_dir = ['%s/%s' % (self.args.parameter_dir, s) for s in ['sw', 'hmm', 'multi-hmm']]
    self.sub_param_dir = utils.parameter_type_subdir(self.args, self.args.parameter_dir)
    self.final_multi_paramdir = utils.non_none([self.args.parameter_out_dir, self.multi_hmm_param_dir])  # ick
    # intermediate files passed to/from the bcrham subprocesses
    self.hmm_infname = self.args.workdir + '/hmm_input.csv'
    self.hmm_cachefname = self.args.workdir + '/hmm_cached_info.csv'
    self.hmm_outfname = self.args.workdir + '/hmm_output.csv'
    self.cpath_progress_dir = '%s/cluster-path-progress' % self.args.workdir  # write the cluster paths for each clustering step to separate files in this dir
    self.print_status = self.args.debug  # if set, print some extra info (e.g. hmm calculation stats) that we used to print by default, but don't want to any more
    if self.args.outfname is not None:
        utils.prep_dir(dirname=None, fname=self.args.outfname, allow_other_files=True)
    self.input_partition, self.input_cpath = None, None
    if self.args.input_partition_fname is not None:
        # read an existing partition (and, if we're continuing from it, its annotations) written by a previous run
        self.input_glfo, self.input_antn_list, self.input_cpath = utils.read_yaml_output(self.args.input_partition_fname, skip_annotations=not self.args.continue_from_input_partition)
        if self.args.continue_from_input_partition:
            self.input_antn_dict = utils.get_annotation_dict(self.input_antn_list, ignore_duplicates=True)  # NOTE not really sure ignore_duplicates should be set, but when we're reading merged subset partitions it's nice to avoid the warnings, and in general duplicates doesn't seem like a big problem
        self.input_partition = self.input_cpath.partitions[self.input_cpath.i_best if self.args.input_partition_index is None else self.args.input_partition_index]
        # print(' %s input partition has duplicates: sum of cluster sizes %d vs %d unique queries' % (utils.wrnstr(), sum(len(c) for c in self.input_partition), len(set(u for c in self.input_partition for u in c))))
        print(' --input-partition-fname: read %s partition with %d sequences in %d clusters from %s' % ('best' if self.args.input_partition_index is None else 'index-%d'%self.args.input_partition_index, sum(len(c) for c in self.input_partition), len(self.input_partition), self.args.input_partition_fname))
        # drop any input queries that aren't in the input partition
        input_partition_queries = set(u for c in self.input_partition for u in c)
        ids_to_rm = set(self.input_info) - input_partition_queries
        for uid in ids_to_rm:
            del self.input_info[uid]
        if len(ids_to_rm) > 0:
            print(' removed %d/%d queries from input info that were absent from input partition' % (len(ids_to_rm), len(self.input_info) + len(ids_to_rm)))
    self.deal_with_persistent_cachefile()
    self.cached_naive_hamming_bounds = self.args.naive_hamming_bounds  # this exists so we don't get the bounds every iteration through the clustering loop (and here is just set to the value from the args, but is set for real below)
    self.aligned_gl_seqs = None
    if self.args.aligned_germline_fname is not None:
        self.aligned_gl_seqs = glutils.read_aligned_gl_seqs(self.args.aligned_germline_fname, self.glfo, self.args.locus, dont_warn_about_duplicates=True)
    # map from command-line action to the method that implements it (see self.run())
    self.action_fcns = {
        'cache-parameters' : self.cache_parameters,
        'annotate' : self.annotate,
        'partition' : self.partition,
        'view-output' : self.read_existing_output,
        'view-annotations' : self.read_existing_output,
        'view-partitions' : self.read_existing_output,
        'plot-partitions' : self.read_existing_output,
        'get-selection-metrics' : self.read_existing_output,
        'get-linearham-info' : self.read_existing_output,
        'update-meta-info' : self.read_existing_output,
        'view-alternative-annotations' : self.view_alternative_annotations,
    }
# ----------------------------------------------------------------------------------------
def sw_cache_path(self, find_any=False):
if self.args.sw_cachefname is not None:
return utils.getprefix(self.args.sw_cachefname)
elif None not in [self.args.parameter_dir, self.input_info]:
if find_any:
fnames = glob.glob(self.args.parameter_dir + '/sw-cache*') # remain suffix-agnostic
if len(fnames) == 0:
raise Exception('couldn\'t find any sw cache files in %s, despite setting <find_any>' % self.args.parameter_dir)
return utils.getprefix(fnames[0])
else:
return self.args.parameter_dir + '/sw-cache-' + utils.uidhashstr(''.join(self.input_info.keys())) # remain suffix-agnostic
else:
return None
# ----------------------------------------------------------------------------------------
def get_cpath_progress_fname(self, istep):
return '%s/istep-%d.csv' % (self.cpath_progress_dir, istep)
# ----------------------------------------------------------------------------------------
def get_all_cpath_progress_fnames(self):
assert len(os.listdir(self.cpath_progress_dir)) == self.istep + 1 # not really checking for anything that has a decent chance of happening, it's more because I get the file names in a slightly different loop in self.clean()
return [self.get_cpath_progress_fname(istep) for istep in range(self.istep + 1)]
# ----------------------------------------------------------------------------------------
def run(self, actions):
self.all_actions = actions
for tmpaction in actions:
self.current_action = tmpaction # NOTE gets changed on the fly below, I think just in self.get_annotations_for_partitions() (which is kind of hackey, but I can't figure out a way to improve on it that wouldn't involve wasting a foolish amount of time rewriting things. Bottom line is that the control flow for different actions is really complicated, and that complexity is going to show up somewhere)
self.action_fcns[tmpaction]()
# ----------------------------------------------------------------------------------------
def clean(self):
    """Write any new-allele fasta output, merge the hmm cache into the persistent cache file, and remove all work files/dirs.

    Raises an Exception if the work dir isn't empty after removing everything we know about.
    """
    if self.args.new_allele_fname is not None:
        new_allele_region = 'v'
        new_alleles = [(g, seq) for g, seq in self.glfo['seqs'][new_allele_region].items() if glutils.is_snpd(g)]
        print(' writing %d new %s to %s' % (len(new_alleles), utils.plural_str('allele', len(new_alleles)), self.args.new_allele_fname))
        with open(self.args.new_allele_fname, 'w') as outfile:
            for name, seq in new_alleles:
                outfile.write('>%s\n' % name)
                outfile.write('%s\n' % seq)
    # merge persistent and current cache files into the persistent cache file
    if self.args.persistent_cachefname is not None:
        lockfname = self.args.persistent_cachefname + '.lock'
        while os.path.exists(lockfname):  # spin until whoever holds the lock releases it
            print(' waiting for lock on %s' % lockfname)
            time.sleep(0.5)
        lockfile = open(lockfname, 'w')
        try:  # make sure we release the lock (and close its handle) even if the merge fails, so other processes don't wait forever
            if not os.path.exists(self.args.persistent_cachefname):
                open(self.args.persistent_cachefname, 'w').close()
            self.merge_files(infnames=[self.args.persistent_cachefname, self.hmm_cachefname], outfname=self.args.persistent_cachefname, dereplicate=True)
        finally:
            lockfile.close()
            os.remove(lockfname)
    if os.path.exists(self.hmm_cachefname):
        os.remove(self.hmm_cachefname)
    for subd in self.subworkdirs:
        if os.path.exists(subd):  # if there was only one proc for this step, it'll have already been removed
            os.rmdir(subd)
    if os.path.exists(self.cpath_progress_dir):  # only exists for partitioning
        for cpfname in self.get_all_cpath_progress_fnames():
            os.remove(cpfname)
        os.rmdir(self.cpath_progress_dir)
    try:
        os.rmdir(self.args.workdir)
    except OSError:
        raise Exception('workdir (%s) not empty: %s' % (self.args.workdir, ' '.join(os.listdir(self.args.workdir))))  # hm... you get weird recursive exceptions if you get here. Oh, well, it still works
# ----------------------------------------------------------------------------------------
def deal_with_persistent_cachefile(self):
    """If a persistent cache file exists, seed the hmm cache file from it (converting from annotation format if necessary)."""
    if self.args.persistent_cachefname is None or not os.path.exists(self.args.persistent_cachefname):  # nothin' to do (ham'll initialize it)
        return
    with open(self.args.persistent_cachefname) as cachefile:
        reader = csv.DictReader(cachefile)
        if set(reader.fieldnames) == set(utils.annotation_headers):  # annotation-style file: would need converting to partition cache format
            raise Exception('doesn\'t work yet')
            # NOTE everything below this raise is unreachable -- it's the (disabled) annotation-to-cache conversion
            print(' parsing annotation output file %s to partition cache file %s' % (self.args.persistent_cachefname, self.hmm_cachefname))
            with open(self.hmm_cachefname, utils.csv_wmode()) as outcachefile:
                writer = csv.DictWriter(outcachefile, utils.partition_cachefile_headers)
                writer.writeheader()
                for line in reader:
                    if line['v_gene'] == '':  # failed
                        continue
                    utils.process_input_line(line)
                    outrow = {'unique_ids' : line['unique_ids'], 'naive_seq' : line['padlefts'][0] * utils.ambig_base + line['naive_seq'] + line['padrights'][0] * utils.ambig_base}
                    writer.writerow(outrow)
        elif set(reader.fieldnames) == set(utils.partition_cachefile_headers):  # headers are ok, so can just copy straight over
            check_call(['cp', self.args.persistent_cachefname, self.hmm_cachefname])
        else:
            raise Exception('--persistent-cachefname %s has unexpected header list %s' % (self.args.persistent_cachefname, reader.fieldnames))
# ----------------------------------------------------------------------------------------
def run_waterer(self, count_parameters=False, write_parameters=False, write_cachefile=False, look_for_cachefile=False, require_cachefile=False, dbg_str=''):
    """Run (or read cached) smith-waterman annotation on self.input_info, leaving results in self.sw_info.

    count_parameters: have the waterer count parameters
    write_parameters: write counted parameters to self.sw_param_dir
    write_cachefile: write the sw annotations to the cache file after running
    look_for_cachefile: read an existing sw cache file if one is present (otherwise run sw)
    require_cachefile: raise if no sw cache file is found
    dbg_str: extra string appended to the 'smith-waterman' header line
    """
    print('smith-waterman%s' % ((' (%s)' % dbg_str) if dbg_str != '' else ''))
    sys.stdout.flush()
    self.vs_info = None  # should already be None, but we want to make sure (if --no-sw-vsearch is set we need it to be None, and if we just removed unlikely alleles we need to rerun vsearch with the likely alleles)
    if not self.args.no_sw_vsearch:
        self.set_vsearch_info(get_annotations=True)
    if self.args.simultaneous_true_clonal_seqs:  # it might be better to just copy over the true indel info in this case? it depends what you're trying to test, and honestly really if you're using this option you just shouldn't be putting indels in your simulation to start with
        print('    note: not running msa indel stuff for --simultaneous-true-clonal-seqs, so any families with shm indels within cdr3 will be split up before running the hmm. To fix this you\'ll either need to run set_msa_info() (which is fine and easy, but slow, and requires deciding whether to make sure to run parameter caching with the arg, or else rerun smith waterman with the msa indels')
    if self.args.all_seqs_simultaneous and self.msa_vs_info is None:  # only run the first time we run sw
        self.set_msa_info(debug=self.args.debug)
        look_for_cachefile, require_cachefile = False, False  # force a rerun so the msa indel info actually gets used
        print('    note: ignoring any existing sw cache file to ensure we\'re getting msa indel info')  # the main use case for this is with 'annotate' or 'partition' on existing parameters that were run on the whole repertoire, so a) it shouldn't be a big deal to rerun and b) you probably don't want to run msa indel info when parameter caching. Also, it's not easy to figure out if msa indel info is in the sw cached file without first reading it
    pre_failed_queries = self.sw_info['failed-queries'] if self.sw_info is not None else None  # don't re-run on failed queries if this isn't the first sw run (i.e., if we're parameter caching)
    waterer = Waterer(self.args, self.glfo, self.input_info, self.simglfo, self.reco_info,  # NOTE if we're reading a cache file, this glfo gets replaced with the glfo from the file
                      count_parameters=count_parameters,
                      parameter_out_dir=self.sw_param_dir if write_parameters else None,
                      plot_annotation_performance=self.args.plot_annotation_performance,
                      duplicates=self.duplicates, pre_failed_queries=pre_failed_queries, aligned_gl_seqs=self.aligned_gl_seqs, vs_info=self.vs_info, msa_vs_info=self.msa_vs_info)
    cache_path = self.sw_cache_path(find_any=require_cachefile)
    cachefname = cache_path + ('.yaml' if self.args.sw_cachefname is None else utils.getsuffix(self.args.sw_cachefname))  # use yaml, unless csv was explicitly set on the command line
    if look_for_cachefile or require_cachefile:
        if os.path.exists(cache_path + '.csv'):  # ...but if there's already an old csv, use that
            cachefname = cache_path + '.csv'
    else:  # i.e. if we're not explicitly told to look for it (and it exists) then it should be out of date
        waterer.clean_cache(cache_path)  # hm, should this be <cachefname> instead of <cache_path>? i mean they're the same, but still
    if (look_for_cachefile or require_cachefile) and os.path.exists(cachefname):
        waterer.read_cachefile(cachefname)
    else:
        if require_cachefile:
            raise Exception('sw cache file %s not found' % cachefname)
        if look_for_cachefile:
            print('    couldn\'t find sw cache file %s, so running sw%s' % (cachefname, ' (this is probably because --seed-unique-id is set to a sequence that wasn\'t in the input file on which we cached parameters [if it\'s inconvenient to put your seed sequences in your input file, you can avoid this by putting them instead in separate file and set --queries-to-include-fname])' if self.args.seed_unique_id is not None else ''))
        waterer.run(cachefname if write_cachefile else None)
    self.sw_info = waterer.info
    self.sw_glfo = waterer.glfo  # ick
    for uid, dupes in waterer.duplicates.items():  # <waterer.duplicates> is <self.duplicates> OR'd into any new duplicates from this run
        self.duplicates[uid] = dupes
    # utils.compare_vsearch_to_sw(self.sw_info, self.vs_info)  # only compares indels a.t.m.

    # NOTE neither of these really works right here, since any time we change the germline set we in really need to go back and rerun sw. But we can't do them betweeen allele finding and running parameter counting sw, since then the counts are wrong
    # # d j allele removal based on snps/counts (just printing for now)
    # if count_parameters and self.current_action == 'cache-parameters':  # I'm not sure this is precisely the criterion I want, but it does the job of not running dj removal printing when we're just annotating with existing parameters (which was causing a crash [key error] with inconsistent glfo)
    #     print ' testing d+j snp-based allele removal'
    #     alremover = AlleleRemover(self.glfo, self.args, simglfo=self.simglfo, reco_info=self.reco_info)
    #     alremover.finalize(gene_counts=None, annotations={q : self.sw_info[q] for q in self.sw_info['queries']}, regions=['d', 'j'], debug=self.args.debug_allele_finding)
    #     print '    (not actually removing d and j alleleremover genes)'
    #     # glutils.remove_genes(self.glfo, alremover.genes_to_remove, debug=True)
    #     # glutils.write_glfo('xxx _output/glfo-test', self.glfo)

    # gene name-based allele removal:
    # if self.args.n_max_alleles_per_gene is not None:  # it would be nice to use AlleleRemover for this, but that's really set up more as a precursor to allele finding, so it ends up being pretty messy to implement
    #     gene_counts = utils.get_gene_counts_from_annotations({q : self.sw_info[q] for q in self.sw_info['queries']})
    #     glutils.remove_extra_alleles_per_gene(self.glfo, self.args.n_max_alleles_per_gene, gene_counts)
    #     # glutils.write_glfo(self.sw_param_dir + '/' + glutils.glfo_dir, self.glfo)  # don't need to rewrite glfo above, since we haven't yet written parameters, but now we do/have
# ----------------------------------------------------------------------------------------
def set_vsearch_info(self, get_annotations=False):
    """Run vsearch on all input sequences and store the results in self.vs_info.

    NOTE setting match:mismatch to optimized values from sw (i.e. 5:-4) results in much worse shm indel performance, so we leave it at the vsearch defaults ('2:-4')
    """
    query_seqs = {}
    for sfo in self.input_info.values():
        query_seqs[sfo['unique_ids'][0]] = sfo['seqs'][0]
    self.vs_info = utils.run_vsearch('search', query_seqs, self.args.workdir + '/vsearch', threshold=0.3, glfo=self.glfo, print_time=True, vsearch_binary=self.args.vsearch_binary, get_annotations=get_annotations, no_indels=self.args.no_indels)
# ----------------------------------------------------------------------------------------
def set_msa_info(self, debug=False):  # NOTE not running this for args.simultaneous_true_clonal_seqs any more, but i'm leaving the stuff in here for that arg in case I change my mind later
    """Build msa-based indel info (mafft alignment + vsearch against the consensus) for each family, storing results in self.msa_vs_info."""
    # ----------------------------------------------------------------------------------------
    def run_msa(cluster):  # NOTE that this is really slow, and could probably be sped up? But i don't really care, the only time you'd run on a lot of families is simulation with tons of indels, which just isn't an important use case
        # align <cluster>'s sequences, build an indel-stripped consensus, then vsearch each seq against that consensus to get cigars/indel info
        unln_seqfos = [{'name' : q, 'seq' : self.input_info[q]['seqs'][0]} for q in cluster]  # ignore the indels that already came from vsearch, combining them would be hard (and we want the rest of the vsearch info for other purposes)
        if self.args.simultaneous_true_clonal_seqs and len(set(len(s['seq']) for s in unln_seqfos)) == 1:  # if all the seqs are the same length, they almost certainly don't have shm indels
            if debug:
                print('  all %d seqs the same length, skipping' % len(unln_seqfos))
            return {'gene-counts' : None, 'annotations' : OrderedDict(), 'failures' : []}
        aln_seqfos = utils.align_many_seqs(unln_seqfos, extra_str='      ', debug=debug)
        cseq = utils.cons_seq(aligned_seqfos=aln_seqfos, extra_str='      ', debug=debug)
        indeld_cseq = []  # cons seq where we remove any "indels" (well, gaps in the msa) that are present in less than half the seqs
        for ich, cons_char in enumerate(cseq):
            msa_chars = [s['seq'][ich] for s in aln_seqfos]
            n_gap_chars = len([c for c in msa_chars if c in utils.gap_chars])
            # print ''.join(msa_chars), n_gap_chars, len(msa_chars), cons_char if n_gap_chars <= len(msa_chars) // 2 else ''
            if n_gap_chars <= len(msa_chars) // 2:  # if it's less than half gap chars, we want the cons char in indeld_cseq
                indeld_cseq.append(cons_char)
        indeld_cseq = ''.join(indeld_cseq)
        if debug:
            print('   indeld cons seq: %s' % indeld_cseq)
        # make a fake single-gene glfo whose one v "gene" is the consensus, so vsearch can give us cigars against it
        fglfo = glutils.get_empty_glfo(self.args.locus)
        fglfo['seqs']['v'] = {'IGHVx-x*x' : indeld_cseq}  # it's not a real v gene, it extends through the whole (vdj) sequence, but i have to put something here, and i think this won't cause problems
        return utils.run_vsearch('search', {s['name'] : s['seq'] for s in unln_seqfos}, self.args.workdir + '/vsearch', threshold=0.3, glfo=fglfo, vsearch_binary=self.args.vsearch_binary, get_annotations=True)  # don't really need to align again, but this gets us the cigar seqs automatically, and i REALLY don't want to write anything more to do with cigars (i.e. converting aln_seqfos to cigars)
    # ----------------------------------------------------------------------------------------
    print('    running mafft+vsearch for msa indel info for --all-seqs-simultaneous/--simultaneous-true-clonal-seqs')
    # decide which sets of sequences to treat as families
    if self.args.all_seqs_simultaneous:  # if you set both of these, that's your problem, it doesn't make sense anyway
        nsets = [[q for q in self.input_info]]  # maybe i should exclude any that failed sw, but otoh if you set all simultaneous, that means you want *all* simultaneous
    elif self.args.simultaneous_true_clonal_seqs:
        nsets = utils.get_partition_from_reco_info(self.reco_info)
    else:
        assert False
    all_antns, all_failed_queries = OrderedDict(), []
    for cluster in nsets:
        cfo = run_msa(cluster)
        all_antns.update(cfo['annotations'])
        all_failed_queries += cfo['failures']
    self.msa_vs_info = {'gene-counts' : None, 'annotations' : all_antns, 'failures' : all_failed_queries}
# ----------------------------------------------------------------------------------------
def cache_parameters(self):
    """Infer the germline set (allele removal/finding) and write sw then hmm parameters to disk.

    Order matters throughout: each germline set modification requires rerunning smith-waterman before counting parameters.
    """
    print('caching parameters')

    # remove unlikely alleles (can only remove v alleles here, since we're using vsearch annotations, but that's ok since it's mostly a speed optimization)
    if not self.args.dont_remove_unlikely_alleles:
        self.set_vsearch_info(get_annotations=(self.args.debug_allele_finding and self.args.is_simu))  # we only use the annotations to print some debug info in alleleremover
        alremover = AlleleRemover(self.glfo, self.args, simglfo=self.simglfo, reco_info=self.reco_info)
        alremover.finalize({'v' : self.vs_info['gene-counts']}, annotations=(None if len(self.vs_info['annotations']) == 0 else self.vs_info['annotations']), debug=self.args.debug_allele_finding)
        glutils.remove_genes(self.glfo, alremover.genes_to_remove)
        self.vs_info = None  # don't want to keep this around, since it has alignments against all the genes we removed (also maybe memory control)
        alremover = None  # memory control (not tested)

    # (re-)add [new] alleles
    if self.args.allele_cluster:
        self.run_waterer(dbg_str='new-allele clustering')
        alclusterer = AlleleClusterer(self.args, glfo=self.glfo, reco_info=self.reco_info, simglfo=self.simglfo)
        alcluster_alleles = alclusterer.get_alleles(self.sw_info, debug=self.args.debug_allele_finding, plotdir=None if self.args.plotdir is None else self.args.plotdir + '/sw/alcluster')
        if len(alcluster_alleles) > 0:
            glutils.add_new_alleles(self.glfo, list(alcluster_alleles.values()), use_template_for_codon_info=False, simglfo=self.simglfo, debug=True)
            if self.aligned_gl_seqs is not None:
                glutils.add_missing_alignments(self.glfo, self.aligned_gl_seqs, debug=True)
        alclusterer = None
    if not self.args.dont_find_new_alleles:
        self.run_waterer(dbg_str='new-allele fitting')
        alfinder = AlleleFinder(self.glfo, self.args)
        new_allele_info = alfinder.increment_and_finalize(self.sw_info, debug=self.args.debug_allele_finding)  # incrementing and finalizing are intertwined since it needs to know the distribution of 5p and 3p deletions before it can increment
        if self.args.plotdir is not None:
            alfinder.plot(self.args.plotdir + '/sw', only_csv=self.args.only_csv_plots)
        if len(new_allele_info) > 0:
            glutils.restrict_to_genes(self.glfo, list(self.sw_info['all_best_matches']))
            glutils.add_new_alleles(self.glfo, new_allele_info, debug=True, simglfo=self.simglfo, use_template_for_codon_info=False)  # <remove_template_genes> stuff is handled in <new_allele_info> (also note, can't use template for codon info since we may have already removed it)
            if self.aligned_gl_seqs is not None:
                glutils.add_missing_alignments(self.glfo, self.aligned_gl_seqs, debug=True)

    # get and write sw parameters
    self.run_waterer(count_parameters=True, write_parameters=True, write_cachefile=True, dbg_str='writing parameters')
    self.write_hmms(self.sw_param_dir)  # note that this modifies <self.glfo>
    if self.args.only_smith_waterman:
        if self.args.outfname is not None:  # NOTE this is _not_ identical to the sw cache file (e.g. padding, failed query writing, plus probably other stuff)
            self.write_output(None, set(), write_sw=True)
        return

    # get and write hmm parameters
    print('hmm')
    sys.stdout.flush()
    _, annotations, hmm_failures = self.run_hmm('viterbi', self.sw_param_dir, parameter_out_dir=self.hmm_param_dir, count_parameters=True, partition=self.input_partition)
    if self.args.outfname is not None and self.current_action == self.all_actions[-1]:
        self.write_output(list(annotations.values()), hmm_failures, cpath=self.input_cpath)
    self.write_hmms(self.hmm_param_dir)  # note that this modifies <self.glfo>
# ----------------------------------------------------------------------------------------
def annotate(self):
    """Run hmm annotation on the input sequences (after smith-waterman), writing output/plots as requested by args."""
    print('annotating (with %s)%s' % (self.sub_param_dir, ' (and star tree annotation, since --subcluster-annotation-size is None)' if self.args.subcluster_annotation_size is None else ''))
    if self.sw_info is None:
        self.run_waterer(look_for_cachefile=not self.args.write_sw_cachefile, write_cachefile=self.args.write_sw_cachefile, count_parameters=self.args.count_parameters)
    if self.args.only_smith_waterman:
        if self.args.outfname is not None:  # NOTE this is _not_ identical to the sw cache file (e.g. padding, failed query writing, plus probably other stuff)
            self.write_output(None, set(), write_sw=True)  # note that if you're auto-parameter caching, this will just be rewriting an sw output file that's already there from parameter caching, but oh, well. If you're setting --only-smith-waterman and not using cache-parameters, you have only yourself to blame
        return
    print('hmm')
    self.added_extra_clusters_to_annotate = False  # ugh (see other places this gets set, and fcn in next line gets called)
    annotations, hmm_failures = self.actually_get_annotations_for_clusters(clusters_to_annotate=self.input_partition)
    if self.args.get_selection_metrics:
        self.calc_tree_metrics(annotations)  # adds tree metrics to <annotations>
    if self.args.annotation_clustering:  # VJ CDR3 clustering (NOTE it would probably be better to have this under 'partition' action, but it's historical and also not very important)
        from . import annotationclustering
        antn_ptn = annotationclustering.vollmers(annotations, self.args.annotation_clustering_threshold)
        antn_cpath = ClusterPath(partition=antn_ptn)
        self.get_annotations_for_partitions(antn_cpath)  # get new annotations corresponding to <antn_ptn>
    else:
        if self.args.outfname is not None:
            # NOTE the antn_cpath branch of this ternary is dead here, since we're in the else of the annotation_clustering check
            self.write_output(list(annotations.values()), hmm_failures, cpath=antn_cpath if self.args.annotation_clustering else self.input_cpath)
        if self.args.plot_partitions or self.input_partition is not None and self.args.plotdir is not None:
            assert self.input_partition is not None
            partplotter = PartitionPlotter(self.args, glfo=self.glfo)
            partplotter.plot(self.args.plotdir + '/partitions', self.input_partition, annotations, reco_info=self.reco_info, args=self.args)
    if self.args.count_parameters and not self.args.dont_write_parameters:
        self.write_hmms(self.final_multi_paramdir)  # note that this modifies <self.glfo>
# ----------------------------------------------------------------------------------------
def calc_tree_metrics(self, annotation_dict, annotation_list=None, cpath=None):
    """Calculate selection metrics (tree metrics) and add them to the annotations in <annotation_dict>.

    annotation_dict: annotations keyed by joined uid string
    annotation_list: same annotations as a list (derived from the dict if not passed)
    cpath: cluster path, used only to restrict to the seed cluster in the best partition when --seed-unique-id is set
    """
    if annotation_list is None:
        annotation_list = list(annotation_dict.values())
    if self.current_action == 'get-selection-metrics' and self.args.input_metafnames is not None:  # presumably if you're running 'get-selection-metrics' with --input-metafnames set, that means you didn't add the affinities (+ other metafo) when you partitioned, so we need to add it now
        seqfileopener.read_input_metafo(self.args.input_metafnames, annotation_list)
    if self.args.seed_unique_id is not None:  # restrict to seed cluster in the best partition (clusters from non-best partition have duplicate uids, which then make fasttree barf, and it doesn't seem worth the trouble to fix it now)
        print('    --seed-unique-id: restricting selection metric calculation to seed cluster in best partition (mostly to avoid fasttree crash on duplicate uids)')
        annotation_dict = OrderedDict([(uidstr, line) for uidstr, line in annotation_dict.items() if self.args.seed_unique_id in line['unique_ids'] and line['unique_ids'] in cpath.partitions[cpath.i_best]])
    treeutils.add_smetrics(self.args, self.args.selection_metrics_to_calculate, annotation_dict, self.args.lb_tau, reco_info=self.reco_info,  # NOTE keys in <annotation_dict> may be out of sync with 'unique_ids' if we add inferred ancestral seqs here
                           use_true_clusters=self.reco_info is not None, base_plotdir=self.args.plotdir, workdir=self.args.workdir,
                           outfname=self.args.selection_metric_fname, glfo=self.glfo, tree_inference_outdir=self.args.tree_inference_outdir, debug=self.args.debug)
# ----------------------------------------------------------------------------------------
def parse_existing_annotations(self, annotation_list, ignore_args_dot_queries=False, process_csv=False):
n_queries_read = 0
failed_query_strs, fake_paired_strs = set(), set()
new_annotation_list = []
for line in annotation_list:
if process_csv:
utils.process_input_line(line)
uidstr = ':'.join(line['unique_ids'])
if ('invalid' in line and line['invalid']) or line['v_gene'] == '': # first way is the new way, but we have to check the empty-v-gene way too for old files
failed_query_strs.add(uidstr)
if line.get('is_fake_paired', False):
fake_paired_strs.add(uidstr)
continue
if self.args.queries is not None and not ignore_args_dot_queries: # second bit is because when printing subcluster naive seqs, we want to read all the ones that have any overlap with self.args.queries, not just the exact cluster of self.args.queries
if len(set(self.args.queries) & set(line['unique_ids'])) == 0: # actually make sure this is the precise set of queries we want (note that --queries and line['unique_ids'] are both ordered, and this ignores that... oh, well, sigh.)
continue
if self.args.reco_ids is not None and line['reco_id'] not in self.args.reco_ids:
continue
utils.add_implicit_info(self.glfo, line)
new_annotation_list.append(line)
n_queries_read += 1
if self.args.n_max_queries > 0 and n_queries_read >= self.args.n_max_queries:
break
if len(failed_query_strs) > 0:
print('\n%d failed queries%s' % (len(failed_query_strs), '' if len(fake_paired_strs) == 0 else ' (%d were fake paired annotations)' % len(fake_paired_strs)))
return new_annotation_list, len(fake_paired_strs)
# ----------------------------------------------------------------------------------------
def view_alternative_annotations(self):
print(' %s getting alternative annotation information from existing output file. These results will only be meaningful if you had --calculate-alternative-annotations set when writing the output file (so that all subcluster annotations were stored). We can\'t check for that here directly, so instead we print this warning to make sure you had it set ;-)' % utils.color('yellow', 'note'))
# we used to require that you set --queries to tell us which to get, but I think now it makes sense to by default just get all of them (but not sure enough to delete this yet)
# if self.args.queries is None:
# _, cpath = self.read_existing_output(read_partitions=True)
# clusterstrs = []
# for cluster in sorted(cpath.partitions[cpath.i_best], key=len, reverse=True):
# clusterstrs.append(' %s' % ':'.join(cluster))
# raise Exception('in order to view alternative annotations, you have to specify (with --queries) a cluster from the final partition. Choose from the following:\n%s' % '\n'.join(clusterstrs))
cluster_annotations, cpath = self.read_existing_output(ignore_args_dot_queries=True, read_partitions=True, read_annotations=True) # note that even if we don't need the cpath to re-write output below, we need to set read_annotations=True, since the fcn gets confused otherwise and doesn't read the right cluster annotation file (for deprecated csv files)
clusters_to_use = cpath.partitions[cpath.i_best] if self.args.queries is None else [self.args.queries]
n_skipped = 0
for cluster in sorted(clusters_to_use, key=len, reverse=True):
if len(cluster) < self.args.min_selection_metric_cluster_size:
n_skipped += 1
continue
self.process_alternative_annotations(cluster, cluster_annotations, cpath=cpath, debug=True)
if n_skipped > 0:
print(' skipped %d clusters smaller than --min-selection-metric-cluster-size %d' % (n_skipped, self.args.min_selection_metric_cluster_size))
print(' note: rewriting output file with newly-calculated alternative annotation info')
self.write_output(list(cluster_annotations.values()), set(), cpath=cpath, dont_write_failed_queries=True) # I *think* we want <dont_write_failed_queries> set, because the failed queries should already have been written, so now they'll just be mixed in with the others in <annotations>
# ----------------------------------------------------------------------------------------
def get_index_restricted_clusters(self, cpath):
tptn = cpath.best()
if self.args.partition_index_to_print is not None:
if self.args.partition_index_to_print > len(cpath.partitions) - 1:
cpath.print_partitions()
print(' %s --partition-index-to-print %d too large for cpath with length %d, so ignoring it and using best partition' % (utils.wrnstr(), self.args.partition_index_to_print, len(cpath.partitions)))
else:
tptn = cpath.partitions[self.args.partition_index_to_print]
tptn = sorted(tptn, key=len, reverse=True) # NOTE we always want this sorted, i.e. dont_sort doesn't apply to this, since here we're only doing --cluster-indices, which says in its help message that we sort
if self.args.cluster_indices is None:
return tptn
else:
return [tptn[i] for i in self.args.cluster_indices]
# ----------------------------------------------------------------------------------------
    def print_results(self, cpath, annotation_list, dont_sort=False, label_list=None, extra_str=''):
        """Print partitions and annotations (plus true partitions/events when simulation info is available), applying the various command-line restriction/emphasis args.

        cpath: ClusterPath to print (may be None, e.g. if only annotations were read)
        annotation_list: annotation lines to print
        dont_sort: if set, don't sort annotations by decreasing cluster size
        label_list: per-annotation extra label strings (must be same length as <annotation_list>)
        extra_str: prefix prepended to printed lines
        """
        if label_list is not None:
            assert len(label_list) == len(annotation_list)
        seed_uid = self.args.seed_unique_id
        true_partition = None
        restricted_clusters = None
        if cpath is not None and len(cpath.partitions) > 0:
            if len(annotation_list) > 0:  # this is here just so we get a warning if any of the clusters in the best partition are missing from the annotations
                _ = utils.get_annotation_dict(annotation_list, cpath=cpath)
            # it's expected that sometimes you'll write a seed partition cpath, but then when you read the file you don't bother to seed the seed id on the command line. The reverse, however, shouldn't happen
            if seed_uid is not None and cpath.seed_unique_id != seed_uid:
                print('  %s seed uids from args and cpath don\'t match %s %s  ' % (utils.color('red', 'error'), self.args.seed_unique_id, cpath.seed_unique_id))
            if self.args.cluster_indices is not None:
                restricted_clusters = self.get_index_restricted_clusters(cpath)
            seed_uid = cpath.seed_unique_id
            n_to_print, ipart_center = None, None
            if self.args.partition_index_to_print is not None:
                print('  --partition-index-to-print: using non-default partition with index %d' % self.args.partition_index_to_print)
                n_to_print, ipart_center = 1, self.args.partition_index_to_print
            print('%s%s' % (extra_str, utils.color('green', 'partitions:')))
            cpath.print_partitions(abbreviate=self.args.abbreviate, reco_info=self.reco_info, highlight_cluster_indices=self.args.cluster_indices,
                                   calc_missing_values=('all' if cpath.n_seqs() < 500 else 'best'), print_partition_indices=True, n_to_print=n_to_print, ipart_center=ipart_center)
            if not self.args.is_data and self.reco_info is not None:  # if we're reading existing output, it's pretty common to not have the reco info even when it's simulation, since you have to also pass in the simulation input file on the command line
                true_partition = utils.get_partition_from_reco_info(self.reco_info)
                true_cp = ClusterPath(seed_unique_id=self.args.seed_unique_id)
                true_cp.add_partition(true_partition, -1., 1)
                print('%strue:' % extra_str)
                # print utils.per_seq_correct_cluster_fractions(cpath.partitions[cpath.i_best], true_partition, reco_info=self.reco_info, seed_unique_id=self.args.seed_unique_id)
                true_cp.print_partitions(self.reco_info, print_header=False, calc_missing_values='best', extrastr=extra_str, print_partition_indices=True)
        if len(annotation_list) > 0:
            print('%s%s' % (extra_str, utils.color('green', 'annotations:')))
            if dont_sort:
                sorted_annotations = annotation_list
            else:
                sorted_annotations = sorted(annotation_list, key=lambda l: len(l['unique_ids']), reverse=True)
            if self.args.cluster_indices is not None:
                print('  --cluster-indices: restricting to %d cluster%s with indices: %s' % (len(self.args.cluster_indices), utils.plural(len(self.args.cluster_indices)), ' '.join(str(i) for i in self.args.cluster_indices)))
                # sorted_annotations = [sorted_annotations[iclust] for iclust in self.args.cluster_indices]  # this is what it used to be, but this is wrong
                antn_dict = utils.get_annotation_dict(sorted_annotations)
                sorted_annotations = [antn_dict.get(':'.join(rc)) for rc in restricted_clusters]
                if None in sorted_annotations:
                    print('  %s missing %d requested annotations' % (utils.color('yellow', 'warning'), sorted_annotations.count(None)))
                    sorted_annotations = [l for l in sorted_annotations if l is not None]
            for iline, line in enumerate(sorted_annotations):
                # skip annotations excluded by the various --only-print-* args
                if self.args.only_print_best_partition and cpath is not None and cpath.i_best is not None and line['unique_ids'] not in cpath.partitions[cpath.i_best]:
                    continue
                if (self.args.only_print_seed_clusters or self.args.seed_unique_id is not None) and seed_uid not in line['unique_ids']:  # we only use the seed id from the command line here, so you can print all the clusters even if you ran seed partitioning UPDATE wait did I change my mind? need to check
                    continue
                if self.args.only_print_queries_to_include_clusters and len(set(self.args.queries_to_include) & set(line['unique_ids'])) == 0:  # will barf if you don't tell us what queries to include, but then that's your fault isn't it
                    continue
                if self.args.print_trees:
                    treestr = line.get('tree', lbplotting.get_tree_in_line(line, self.args.is_simu))  # ok this weird, but i want to be able to print the tree on simulation files even without setting --is-simu, sincei if --is-simu is set i may not be able to print the annotations (since for that, reco_info has to be set, but that depends how the simulation file was read. Anyway...
                    if treestr is None:
                        print('  --print-trees: no tree found in line')
                    else:
                        dtree = treeutils.get_dendro_tree(treestr=treestr)
                        print(utils.pad_lines(treeutils.get_ascii_tree(dendro_tree=dtree)))
                    continue  # eh, maybe just continue so it doesn't crash if it sees the multi-seq true annotation
                label, post_label = [], []
                if self.args.infname is not None and self.reco_info is not None:
                    utils.print_true_events(self.simglfo, self.reco_info, line, full_true_partition=true_partition, extra_str=extra_str+'  ')
                    label += ['inferred:']
                if cpath is not None and cpath.i_best is not None:  # maybe I could do the iparts stuff even if i_best isn't set, but whatever, I think it's only really not set if the cpath is null anyway
                    iparts = cpath.find_iparts_for_cluster(line['unique_ids'])
                    ipartstr = 'none' if len(iparts) == 0 else ' '.join([str(i) for i in iparts])
                    post_label += ['   partition%s: %s' % (utils.plural(len(iparts)), ipartstr)]
                    if cpath.i_best in iparts:
                        post_label += [', %s' % utils.color('yellow', 'best')]
                queries_to_emphasize = []
                if seed_uid is not None and seed_uid in line['unique_ids']:
                    post_label += [', %s' % utils.color('red', 'seed')]
                    queries_to_emphasize += [seed_uid]
                if self.args.queries_to_include is not None and len(set(self.args.queries_to_include) & set(line['unique_ids'])) > 0:  # will barf if you don't tell us what queries to include, but then that's your fault isn't it
                    post_label += [', %s' % utils.color('red', 'queries-to-include')]
                    queries_to_emphasize += self.args.queries_to_include
                if label_list is not None:
                    post_label += [label_list[iline]]
                utils.print_reco_event(line, extra_str=extra_str+'  ', label=''.join(label), post_label=''.join(post_label), queries_to_emphasize=queries_to_emphasize, extra_print_keys=self.args.extra_print_keys)
# ----------------------------------------------------------------------------------------
def restrict_ex_out_clusters(self, cpath, annotation_list): # NOTE would be nice to use [bits of] this also for result printing fcn above, but there we need the indices to line up, so have to do the 'continue' thing
n_before = len(annotation_list)
ptn_to_use, dbg_str = cpath.best() , []
if self.args.only_print_best_partition and cpath is not None and cpath.i_best is not None:
annotation_list = [l for l in annotation_list if l['unique_ids'] in cpath.partitions[cpath.i_best]]
if self.args.only_print_seed_clusters or self.args.seed_unique_id is not None:
annotation_list = [l for l in annotation_list if self.args.seed_unique_id in l['unique_ids']]
ptn_to_use = [c for c in ptn_to_use if self.args.seed_unique_id in c]
if self.args.only_print_queries_to_include_clusters:
annotation_list = [l for l in annotation_list if len(set(self.args.queries_to_include) & set(l['unique_ids'])) > 0] # will barf if you don't tell us what queries to include, but then that's your fault isn't it
ptn_to_use = [c for c in ptn_to_use if len(set(self.args.queries_to_include) & set(c)) > 0]
if self.args.n_final_clusters is not None or self.args.min_largest_cluster_size is not None:
tptns = cpath.partitions
if self.args.n_final_clusters is not None:
tptns = [p for p in tptns if len(p) == self.args.n_final_clusters]
dbg_str.append('--n-final-clusters')
if self.args.min_largest_cluster_size is not None:
tptns = [p for p in cpath.partitions if any(len(c) >= self.args.min_largest_cluster_size for c in p)]
dbg_str.append('--min-largest-cluster-size')
if len(tptns) > 1:
print(' %s multiple partitions satisfy --n-final-clusters/--min-largest-cluster-size criteria, just picking first one' % utils.wrnstr())
ptn_to_use = cpath.partitions[-1] if len(tptns)==0 else tptns[0]
annotation_list = [l for l in annotation_list if l['unique_ids'] in ptn_to_use]
if self.args.partition_index_to_print is not None or self.args.cluster_indices is not None:
ptn_to_use = self.get_index_restricted_clusters(cpath)
annotation_list = [l for l in annotation_list if l['unique_ids'] in ptn_to_use]
if self.args.partition_index_to_print is not None:
dbg_str.append('--partition-index-to-print')
if self.args.cluster_indices is not None:
dbg_str.append('--cluster-indices')
if self.args.only_print_best_partition or self.args.only_print_seed_clusters or self.args.only_print_queries_to_include_clusters or len(dbg_str) > 0:
astr = ', '.join(['--only-print-'+s for s in ['best-partition', 'seed-clusters', 'queries-to-include-clusters'] if getattr(self.args, ('only-print-'+s).replace('-', '_'))] + dbg_str)
print(' %s: restricting to %d/%d annotations' % (astr, len(annotation_list), n_before))
else:
print(' note: By default we print/operate on *all* annotations in the output file, which in general can include annotations from non-best partititons and non-seed clusters (e.g. if --n-final-clusters was set).\n If you want to restrict to particular annotations, use one of --only-print-best-partition, --only-print-seed-clusters, or --only-print-queries-to-include-clusters (or, if set during partitioning, --n-final-clusters or --min-largest-cluster-size).')
return ptn_to_use, annotation_list
# ----------------------------------------------------------------------------------------
def read_existing_output(self, outfname=None, ignore_args_dot_queries=False, read_partitions=False, read_annotations=False):
if outfname is None:
outfname = self.args.outfname
annotation_list = []
cpath = None
tmpact = self.current_action # just a shorthand for brevity
if utils.getsuffix(outfname) == '.csv': # old way
if tmpact == 'view-partitions' or tmpact == 'plot-partitions' or tmpact == 'view-output' or tmpact == 'get-selection-metrics' or read_partitions:
cpath = ClusterPath(seed_unique_id=self.args.seed_unique_id, fname=outfname)
if tmpact == 'view-annotations' or tmpact == 'plot-partitions' or tmpact == 'view-output' or tmpact == 'get-selection-metrics' or read_annotations:
csvfile = open(outfname if cpath is None else self.args.cluster_annotation_fname) # closes on function exit, and no this isn't a great way of doing it (but it needs to stay open for the loop)
reader = csv.DictReader(csvfile)
if 'unique_ids' not in reader.fieldnames:
raise Exception('not an annotation file: %s' % outfname)
annotation_list = list(reader)
elif utils.getsuffix(outfname) == '.yaml': # new way
# NOTE replaces <self.glfo>, which is definitely what we want (that's the point of putting glfo in the yaml file), but it's still different behavior than if reading a csv
assert self.glfo is None # make sure bin/partis successfully figured out that we would be reading the glfo from the yaml output file
self.glfo, annotation_list, cpath = utils.read_yaml_output(outfname, n_max_queries=self.args.n_max_queries, dont_add_implicit_info=True, seed_unique_id=self.args.seed_unique_id) # add implicit info below, so we can skip some of 'em
else:
raise Exception('unhandled annotation file suffix %s' % outfname)
annotation_list, n_fake_paired = self.parse_existing_annotations(annotation_list, ignore_args_dot_queries=ignore_args_dot_queries, process_csv=utils.getsuffix(outfname) == '.csv') # NOTE modifies <annotation_list>
ptn_to_use, annnotation_list = self.restrict_ex_out_clusters(cpath, annotation_list)
if len(annotation_list) == 0:
if cpath is not None and tmpact in ['view-output', 'view-annotations', 'view-partitions']:
self.print_results(cpath, []) # used to just return, but now i want to at least see the cpath
print('zero annotations to print, exiting%s' % ('' if n_fake_paired==0 else ' (%s %d were fake paired annotations)'%(utils.color('yellow', 'note'), n_fake_paired)))
return
annotation_dict = utils.get_annotation_dict(annotation_list) # returns none type if there's duplicate annotations
extra_headers = list(set([h for l in annotation_list for h in l.keys() if h not in utils.annotation_headers])) # note that this basically has to be hackey/wrong, since we're trying to guess what the headers were when the file was written
if tmpact == 'get-linearham-info':
self.input_info = OrderedDict([(u, {'unique_ids' : [u], 'seqs' : [s]}) for l in annotation_list for u, s in zip(l['unique_ids'], l['input_seqs'])]) # this is hackey, but I think is ok (note that the order won't be the same as it would've been before)
self.run_waterer(require_cachefile=True)
uids_missing_sw_info = [u for l in annotation_list for u in l['unique_ids'] if u not in self.sw_info] # make sure <annotation_list> and <self.sw_info> have the same uids (they can get out of sync because we're re-running sw here with potentially different options (or versions) to whatever command was run to make <annotation_list>, e.g. if --is-simu was turned on duplicates won't have been removed before)
dup_dict = {d : u for u in self.sw_info['queries'] for d in self.sw_info[u]['duplicates'][0]}
for uid in uids_missing_sw_info: # NOTE <self.sw_info> is somewhat inconsistent after we do this, but this code should only get run when we're just adding linearham info so idgaf
if uid in dup_dict:
self.sw_info[uid] = self.sw_info[dup_dict[uid]] # it would be proper to fix the duplicates in here, and probably some other things
else:
pass # switching this to pass: even though I think it can ony happen if people are running with options that don't really make sense, I don't think there's really a pressing need to crash here raise Exception('no sw info for query %s' % uid) # I can't really do anything else, it makes no sense to go remove it from <annotation_list> when the underlying problem is (probably) that sw info was just run with different options
self.write_output(annotation_list, set(), cpath=cpath, outfname=self.args.linearham_info_fname, dont_write_failed_queries=True, extra_headers=extra_headers) # I *think* we want <dont_write_failed_queries> set, because the failed queries should already have been written, so now they'll just be mixed in with the others in <annotation_list>
if tmpact == 'get-selection-metrics':
self.calc_tree_metrics(annotation_dict, annotation_list=annotation_list, cpath=cpath) # adds tree metrics to <annotations>
if tmpact == 'update-meta-info':
inids, alist_ids = set(self.input_info), set(u for l in annotation_list for u in l['unique_ids'])
if inids != alist_ids:
print(' %s input info (len %d) has different uids to annotation list (len %d), and meta info will only be set/correct for the ones in input info (%d missing from input info, %d missing from annotation list, %d in common)' % (utils.wrnstr(), len(inids), len(alist_ids), len(alist_ids - inids), len(inids - alist_ids), len(inids & alist_ids)))
# may want to add this? not sure: overwrite_all=True
seqfileopener.add_input_metafo(self.input_info, annotation_list, keys_not_to_overwrite=['multiplicities', 'paired-uids']) # these keys are modified by sw (multiplicities) or paired clustering (paired-uids), so if you want to update them with this action here you're out of luck
if tmpact == 'update-meta-info' or (tmpact == 'get-selection-metrics' and self.args.add_selection_metrics_to_outfname):
print(' rewriting output file with %s: %s' % ('newly-calculated selection metrics' if tmpact=='get-selection-metrics' else 'updated input meta info', outfname))
if self.args.add_selection_metrics_to_outfname and 'gctree' in self.args.tree_inference_method:
print(' %s writing gctree annotations (with inferred ancestral sequences added) to original output file, which means that if you rerun gctree things may crash/be messed up since the inferred ancestral sequences are already in the annotation' % utils.wrnstr())
self.write_output(annotation_list, set(), cpath=cpath, dont_write_failed_queries=True, extra_headers=extra_headers) # I *think* we want <dont_write_failed_queries> set, because the failed queries should already have been written, so now they'll just be mixed in with the others in <annotation_list>
if self.args.align_constant_regions:
utils.parse_constant_regions(self.args.species, self.args.locus, annotation_list, self.args.workdir, csv_outdir=os.path.realpath(os.path.dirname(self.args.outfname)) if self.args.outfname is not None else None, debug=self.args.debug)
if tmpact == 'plot-partitions':
partplotter = PartitionPlotter(self.args, glfo=self.glfo)
partplotter.plot(self.args.plotdir + '/partitions', ptn_to_use, annotation_dict, reco_info=self.reco_info, args=self.args)
if tmpact in ['view-output', 'view-annotations', 'view-partitions']:
self.print_results(cpath, annotation_list)
return annotation_dict, cpath
# ----------------------------------------------------------------------------------------
    def partition(self):
        """ Partition sequences in <self.input_info> into clonally related lineages """
        print('partitioning (with %s)' % self.sub_param_dir)
        if self.sw_info is None:
            self.run_waterer(look_for_cachefile=not self.args.write_sw_cachefile, write_cachefile=self.args.write_sw_cachefile, count_parameters=False)  # self.args.count_parameters) # run smith-waterman
        if len(self.sw_info['queries']) == 0:  # everything failed sw: write an empty output file (if requested) and bail
            if self.args.outfname is not None:
                self.write_output([], set())
            return
        if self.args.only_smith_waterman:
            return
        print('hmm')
        # pre-cache hmm naive seq for each single query NOTE <self.current_action> is still 'partition' for this (so that we build the correct bcrham command line)
        if self.args.persistent_cachefname is not None:
            print('  --persistent-cachefname: using existing hmm cache file %s' % self.args.persistent_cachefname)
        if self.args.persistent_cachefname is None or not os.path.exists(self.hmm_cachefname):  # if the default (no persistent cache file), or if a not-yet-existing persistent cache file was specified
            print('%scaching all %d naive sequences' % ('' if self.print_status else '    ', len(self.sw_info['queries'])), end='\n' if self.input_partition is not None and self.args.continue_from_input_partition else ' ')
            if self.args.synthetic_distance_based_partition:
                self.write_bcrham_cache_file([[q] for q in self.sw_info['queries']])
            elif self.input_partition is not None and self.args.continue_from_input_partition:
                self.write_bcrham_cache_file(self.input_partition, ctype='input')
            else:
                self.run_hmm('viterbi', self.sub_param_dir, precache_all_naive_seqs=True)  # , n_procs=self.auto_nprocs(len(self.sw_info['queries']))
        # decide how to get the actual partition: several shortcut args replace real clustering with a pre-determined partition
        if self.args.simultaneous_true_clonal_seqs:
            print('  --simultaneous-true-clonal-seqs: using true clusters instead of partitioning')
            true_partition = [[uid for uid in cluster if uid in self.sw_info] for cluster in utils.get_partition_from_reco_info(self.reco_info)]  # mostly just to remove duplicates, although I think there might be other reasons why a uid would be missing
            cpath = ClusterPath(seed_unique_id=self.args.seed_unique_id, partition=true_partition)
        elif self.args.all_seqs_simultaneous:
            print('  --all-seqs-simultaneous: using single cluster instead of partitioning')
            one_clust_ptn = [[u for u in self.sw_info['queries']]]
            cpath = ClusterPath(seed_unique_id=self.args.seed_unique_id, partition=one_clust_ptn)
        elif self.input_partition is not None and not self.args.continue_from_input_partition:
            print('  --input-partition-fname: using input cpath instead of running partitioning')
            cpath = self.input_cpath
        elif self.args.naive_vsearch:  # or self.args.naive_swarm:
            cpath = self.cluster_with_naive_vsearch_or_swarm(parameter_dir=self.sub_param_dir)
        else:
            cpath = self.cluster_with_bcrham()  # the default: actual hierarchical-agglomeration partitioning
        self.get_annotations_for_partitions(cpath)
        self.check_partition(cpath.partitions[cpath.i_best])
# ----------------------------------------------------------------------------------------
def split_seeded_clusters(self, old_cpath): # NOTE similarity to clusterpath.remove_unseeded_clusters()
start = time.time()
seeded_clusters, unseeded_clusters = utils.split_partition_with_criterion(old_cpath.partitions[old_cpath.i_best_minus_x], lambda cluster: self.args.seed_unique_id in cluster)
self.unseeded_seqs = [uid for uclust in unseeded_clusters for uid in uclust] # note that we no longer expect them to all be singletons, since we're merging queries with identical naive seqs before passing to glomerator.cc
seeded_singleton_set = set([uid for sclust in seeded_clusters for uid in sclust]) # in case there's duplicates
seeded_partition = utils.collapse_naive_seqs(self.synth_sw_info(seeded_singleton_set), split_by_cdr3=True)
seeded_cpath = ClusterPath(seed_unique_id=self.args.seed_unique_id)
seeded_cpath.add_partition(seeded_partition, -1., 1)
print(' removed %d sequences in unseeded clusters,' % len(self.unseeded_seqs), end=' ')
print('split %d seeded clusters into %d singletons, and merged these into %d clusters with identical naive seqs (%.1f sec)' % (len(seeded_clusters), len(seeded_singleton_set), len(seeded_cpath.partitions[seeded_cpath.i_best_minus_x]), time.time() - start))
return seeded_cpath
# ----------------------------------------------------------------------------------------
def remove_small_clusters(self, old_cpath):
assert self.small_cluster_seqs is None # at least for now, we want to call this only once (would otherwise need to modify it)
big_clusters, small_clusters = utils.split_partition_with_criterion(old_cpath.partitions[old_cpath.i_best_minus_x], lambda cluster: len(cluster) not in self.args.small_clusters_to_ignore)
if self.args.seed_unique_id is not None: # should probably be implemented at some point
print(' %s not specifically keeping --seed-unique-id sequence when removing small clusters' % utils.wrnstr())
if self.args.queries_to_include is not None:
kept_clusts = []
for ism, sclust in enumerate(small_clusters):
if any(q in sclust for q in self.args.queries_to_include):
small_clusters[ism] = None
big_clusters.append(sclust)
kept_clusts.append(sclust)
small_clusters = [c for c in small_clusters if c is not None]
if len(kept_clusts) > 0:
print(' --queries-to-include: keeping %d small clusters that include specified queries with sizes: %s' % (len(kept_clusts), ' '.join(str(len(c)) for c in sorted(kept_clusts, reverse=True))))
self.small_cluster_seqs = [sid for sclust in small_clusters for sid in sclust]
new_cpath = ClusterPath(seed_unique_id=self.args.seed_unique_id)
new_cpath.add_partition(big_clusters, -1., 1)
ntot = sum(len(c) for c in old_cpath.best())
print(' --small-clusters-to-ignore: removing %d / %d (%.3f) sequences in %d / %d (%.3f) clusters (with sizes among %s)' % (len(self.small_cluster_seqs), ntot, len(self.small_cluster_seqs) / ntot, len(small_clusters), len(old_cpath.best()), len(small_clusters) / len(old_cpath.best()), ' '.join([str(sz) for sz in self.args.small_clusters_to_ignore])))
return new_cpath
# ----------------------------------------------------------------------------------------
def scale_n_procs_for_new_n_clusters(self, initial_nseqs, initial_nprocs, cpath):
new_n_clusters = len(cpath.partitions[cpath.i_best_minus_x]) # when removing small clusters, this is the number of clusters, not the number of sequences, but it's maybe still ok
int_initial_seqs_per_proc = max(1, int(float(initial_nseqs) / initial_nprocs))
new_n_procs = max(1, int(float(new_n_clusters) / int_initial_seqs_per_proc))
if new_n_clusters > 20:
new_n_procs *= 3 # multiply by something 'cause we're turning off the seed uid for the last few times through
if self.args.batch_system is None:
new_n_procs = min(new_n_procs, multiprocessing.cpu_count())
new_n_procs = min(new_n_procs, self.args.n_procs) # don't let it be bigger than whatever was initially specified
print(' new n_procs %d (initial seqs/proc: %.2f new seqs/proc: %.2f' % (new_n_procs, float(initial_nseqs) / initial_nprocs, float(new_n_clusters) / new_n_procs))
return new_n_procs
# ----------------------------------------------------------------------------------------
def shall_we_reduce_n_procs(self, last_n_procs):
if self.timing_info[-1]['total'] < self.args.min_hmm_step_time: # mostly for when you're running on really small samples
return True
n_calcd_per_process = self.get_n_calculated_per_process()
if n_calcd_per_process < self.args.n_max_to_calc_per_process and last_n_procs > 2: # should be replaced by time requirement, since especially in later iterations, the larger clusters make this a crappy metric (2 is kind of a special case, becase, well, small integers and all)
return True
times_to_try_this_n_procs = max(4, last_n_procs) # if we've already milked this number of procs for most of what it's worth (once you get down to 2 or 3, you don't want to go lower)
if self.n_proc_list.count(last_n_procs) >= times_to_try_this_n_procs:
return True
return False
# ----------------------------------------------------------------------------------------
def prepare_next_iteration(self, cpath, initial_nseqs):
last_n_procs = self.n_proc_list[-1]
next_n_procs = last_n_procs
factor = 1.3
if self.shall_we_reduce_n_procs(last_n_procs):
next_n_procs = int(next_n_procs / float(factor))
def time_to_remove_some_seqs(n_proc_threshold):
return len(self.n_proc_list) >= n_proc_threshold or next_n_procs == 1
if self.args.small_clusters_to_ignore is not None and self.small_cluster_seqs is None and time_to_remove_some_seqs(self.args.n_steps_after_which_to_ignore_small_clusters):
cpath = self.remove_small_clusters(cpath)
next_n_procs = self.scale_n_procs_for_new_n_clusters(initial_nseqs, self.n_proc_list[0], cpath)
if self.args.seed_unique_id is not None and self.unseeded_seqs is None and time_to_remove_some_seqs(3): # if we didn't already remove the unseeded clusters in a previous partition step
if (self.args.n_final_clusters is not None or self.args.min_largest_cluster_size is not None) and not self.set_force_args: # need to add an additional iteration here with at least one of the force args set
self.set_force_args = True
next_n_procs = last_n_procs
else:
cpath = self.split_seeded_clusters(cpath)
next_n_procs = self.scale_n_procs_for_new_n_clusters(initial_nseqs, self.n_proc_list[0], cpath)
return next_n_procs, cpath
# ----------------------------------------------------------------------------------------
def get_n_calculated_per_process(self):
assert self.bcrham_proc_info is not None
total = 0. # sum over each process
for procinfo in self.bcrham_proc_info:
if procinfo['calcd'].get('vtb') is None or procinfo['calcd'].get('fwd') is None:
print('%s couldn\'t find vtb/fwd in:\n%s' % (utils.color('red', 'warning'), procinfo['calcd'])) # may as well not fail, it probably just means we lost some stdout somewhere (or are using the Zig backend which doesn't emit calcd debug strings). Which, ok, is bad, but let's say it shouldn't be fatal.
return 1. # er, or something?
if self.args.naive_hamming_cluster: # make sure we didn't accidentally calculate some fwds
assert procinfo['calcd']['fwd'] == 0.
total += procinfo['calcd']['vtb'] + procinfo['calcd']['fwd']
if self.args.debug:
print(' vtb + fwd calcd: %d (%.1f per proc)' % (total, float(total) / len(self.bcrham_proc_info)))
return float(total) / len(self.bcrham_proc_info)
# ----------------------------------------------------------------------------------------
def merge_shared_clusters(self, cpath, debug=False): # replace the most likely partition with a new partition in which any clusters that share a sequence have been merged
# cpath.partitions[cpath.i_best] = [['a', 'b', 'c', 'e'], ['d'], ['f', 'a'], ['g'], ['h'], ['i'], ['j', 'a'], ['x', 'y', 'z', 'd'], ['xx', 'x']]
partition = cpath.partitions[cpath.i_best]
if debug:
print('merging shared clusters')
cpath.print_partitions()
# find every pair of clusters that has some overlap
cluster_groups = []
if debug:
print(' making cluster_groups')
for iclust in range(len(partition)):
for jclust in range(iclust + 1, len(partition)):
if len(set(partition[iclust]) & set(partition[jclust])) > 0:
if debug:
print(' %d %d' % (iclust, jclust))
cluster_groups.append(set([iclust, jclust]))
# merge these pairs of clusters into groups
while True:
no_more_merges = True
for cp1, cp2 in itertools.combinations(cluster_groups, 2):
if len(cp1 & cp2) > 0:
if debug:
print(' merging %s and %s' % (cp1, cp2))
cluster_groups.append(cp1 | cp2)
cluster_groups.remove(cp1)
cluster_groups.remove(cp2)
no_more_merges = False
break # we've modified it now, so we have to go back and remake the iterator
if no_more_merges:
break
# actually merge the groups of clusters
new_clusters = []
for cgroup in cluster_groups:
new_clusters.append(list(set([uid for iclust in cgroup for uid in partition[iclust]])))
if debug:
print(' removing')
for iclust in sorted([i for cgroup in cluster_groups for i in cgroup], reverse=True):
if debug:
print(' %d' % iclust)
partition.pop(iclust)
for nclust in new_clusters:
partition.append(nclust)
if debug:
cpath.print_partitions()
# ----------------------------------------------------------------------------------------
def are_we_finished_clustering(self, n_procs, cpath):
if n_procs == 1:
return True
elif self.args.n_final_clusters is not None and len(cpath.partitions[cpath.i_best]) <= self.args.n_final_clusters: # NOTE I *think* I want the best, not best-minus-x here (hardish to be sure a.t.m., since I'm not really using the minus-x part right now)
print(' stopping with %d (<= %d) clusters' % (len(cpath.partitions[cpath.i_best]), self.args.n_final_clusters))
return True
elif self.args.max_cluster_size is not None and max([len(c) for c in cpath.partitions[cpath.i_best]]) > self.args.max_cluster_size: # NOTE I *think* I want the best, not best-minus-x here (hardish to be sure a.t.m., since I'm not really using the minus-x part right now)
print(' --max-cluster-size (partitiondriver): stopping with a cluster of size %d (> %d)' % (max([len(c) for c in cpath.partitions[cpath.i_best]]), self.args.max_cluster_size))
return True
else:
return False
# ----------------------------------------------------------------------------------------
def synth_sw_info(self, queries): # only used for passing info to utils.collapse_naive_seqs()
# this uses the cached hmm naive seqs (since we have them and they're better) but then later we pass the hmm the sw annotations, so we have to make sure the sw cdr3 length is the same within each cluster (it's very rare that it isn't)
synth_sw_info = {q : {'naive_seq' : s, 'cdr3_length' : self.sw_info[q]['cdr3_length']} for q, s in self.get_cached_hmm_naive_seqs(queries).items()} # NOTE code duplication in cluster_with_bcrham()
synth_sw_info['queries'] = list(synth_sw_info.keys())
return synth_sw_info
# ----------------------------------------------------------------------------------------
def init_cpath(self, n_procs):
initial_nseqs = len(self.sw_info['queries']) # maybe this should be the number of clusters, now that we're doing some preclustering here?
if self.input_partition is not None and self.args.continue_from_input_partition:
print(' --continue-from-input-partition: using input partition for initial cpath')
cpath = self.input_cpath
print(' %d clusters (%d seqs)' % (len(cpath.best()), sum(len(c) for c in cpath.best())))
# maybe i should split by cdr3?
# nsets = utils.split_clusters_by_cdr3(nsets, self.sw_info, warn=True)
else:
initial_nsets = utils.collapse_naive_seqs(self.synth_sw_info(self.sw_info['queries']), split_by_cdr3=True, debug=True)
cpath = ClusterPath(seed_unique_id=self.args.seed_unique_id)
cpath.add_partition(initial_nsets, logprob=0., n_procs=n_procs) # NOTE sw info excludes failed sequences (and maybe also sequences with different cdr3 length)
os.makedirs(self.cpath_progress_dir)
if self.args.debug:
print(' initial cpath:')
cpath.print_partitions(abbreviate=self.args.abbreviate, reco_info=self.reco_info)
return cpath, initial_nseqs
# ----------------------------------------------------------------------------------------
    def merge_cpaths_from_previous_steps(self, final_cpath, debug=False):
        """Prepend partitions from earlier clustering steps' progress files to <final_cpath>.

        Each clustering step wrote its ClusterPath to a file in the cpath progress dir (the last
        file corresponds to <final_cpath>).  We read backwards through those files, prepending
        each earlier step's partitions, until the best partition is preceded by at least
        <n_before> partitions or we run out of files.  When the same partition appears at a step
        boundary, the later copy is kept (it may have a real, non-infinite logprob).
        Returns the merged ClusterPath.
        """
        if debug:
            print('final (unmerged) cpath:')
            final_cpath.print_partitions(abbreviate=self.args.abbreviate)
            print('')
        # decide how many partitions before/after the best one we need to keep
        n_before, n_after = self.args.n_partitions_to_write, self.args.n_partitions_to_write  # this takes more than we need, since --n-partitions-to-write is the *full* width, not half-width, but oh, well
        if self.args.debug or (self.args.calculate_alternative_annotations and self.args.subcluster_annotation_size is None) or self.args.get_selection_metrics:  # take all of 'em
            n_before, n_after = sys.maxsize, sys.maxsize
        elif self.args.write_additional_cluster_annotations is not None:
            n_before, n_after = [max(waca, n_) for waca, n_ in zip(self.args.write_additional_cluster_annotations, (n_before, n_after))]
        # NOTE we don't actually do anything with <n_after>, since we can't add any extra partitions here (well, we don't want to)
        cpfnames = self.get_all_cpath_progress_fnames()  # list of cpath files for each clustering step (last one corresponds to <final_cpath>)
        if final_cpath.i_best >= n_before or len(cpfnames) < 2:  # if we already have enough partitions, or if there was only one step, there's nothing to do
            if debug:
                print('    nothing to merge')
            return final_cpath
        icpfn = len(cpfnames) - 1
        merged_cp = ClusterPath(fname=cpfnames[icpfn], seed_unique_id=self.args.seed_unique_id)  # merged one is initially just the cp from the last step
        assert merged_cp.partitions[merged_cp.i_best] == final_cpath.partitions[final_cpath.i_best]  # shouldn't really be necessary, and is probably kind of slow
        while merged_cp.i_best < n_before and icpfn > 0:  # keep trying to add them until we have <n_before> of them
            icpfn -= 1
            previous_cp = ClusterPath(fname=cpfnames[icpfn], seed_unique_id=self.args.seed_unique_id)
            for ip in range(len(merged_cp.partitions)):
                if len(merged_cp.partitions[ip]) == len(previous_cp.partitions[-1]):  # skip identical partitions (for speed, first check if they have the same number of clusters, then whether the clusters are the same)
                    if set([tuple(c) for c in merged_cp.partitions[ip]]) == set([tuple(c) for c in previous_cp.partitions[-1]]):  # no, they're not always in the same order (I think because they get parcelled out to different processes, and then read back in random order)
                        if math.isinf(merged_cp.logprobs[ip]) and not math.isinf(previous_cp.logprobs[-1]):  # it should only be possible for the *later* partition to have non-infinite logprob, since we usually only calculate full logprobs in the last clustering step (which is why we're taking the later partition, from merged_cp), so print an error if the earlier one, that we're about to throw away, is the one that's non-infinite
                            print('%s earlier partition (that we\'re discarding) has non-infinite logprob %f, while later partition\'s is infinite %f' % (utils.color('red', 'error'), previous_cp.logprobs[-1], merged_cp.logprobs[ip]))
                        previous_cp.remove_partition(len(previous_cp.partitions) - 1)  # remove it from previous_cp, since we want the one that may have a logprob set (and which has smaller n_procs, although I don't think we care about that)
                previous_cp.add_partition(list(merged_cp.partitions[ip]), merged_cp.logprobs[ip], merged_cp.n_procs[ip])  # add each partition in the existing merged cp to the previous cp
            merged_cp = previous_cp
            assert merged_cp.partitions[merged_cp.i_best] == final_cpath.partitions[final_cpath.i_best]  # shouldn't really be necessary, and is probably kind of slow
            if debug:
                print('%s' % utils.color('red', str(icpfn)))
                merged_cp.print_partitions()
        if icpfn > 0:
            print('    %s not merging entire cpath history' % utils.color('yellow', 'note'))
        # this kind of sucks, and shouldn't be necessary, but someone reported that they're seeing cluster paths with all logprobs -inf (which is really bad since the actual ClusterPath code will know that the last partition should be the best, but someone reading the file by hand won't know). If I had a working example of this happening i could figure out a better place to put this check, but I don't
        if all(l == float('-inf') for l in merged_cp.logprobs):
            print('    %s all %d partitions in cluster path have log prob -inf, so setting last (best) to zero' % (utils.wrnstr(), len(merged_cp.partitions)))
            merged_cp.logprobs[-1] = 0.
        return merged_cp
# ----------------------------------------------------------------------------------------
def cluster_with_bcrham(self):
tmpstart = time.time()
self.set_force_args = False # annoying shenanigans to make sure that if both --seed-unique-id and either of --n-final-clusters or --min-largest-cluster-size are set, that the "force" args are set in bcrham *before* we remove unseeded seqs
n_procs = self.args.n_procs
cpath, initial_nseqs = self.init_cpath(n_procs)
self.n_proc_list = []
self.istep = 0
start = time.time()
while n_procs > 0:
if n_procs > len(cpath.bmx()):
print(' reducing n procs to number of clusters: %d --> %d' % (n_procs, len(cpath.bmx())))
n_procs = len(cpath.bmx())
print('%s%d clusters with %d proc%s%s' % ('' if self.print_status else ' ', len(cpath.bmx()), n_procs, utils.plural(n_procs), '\n' if self.print_status else ''), end=' ') # NOTE that a.t.m. i_best and i_best_minus_x are usually the same, since we're usually not calculating log probs of partitions (well, we're trying to avoid calculating any extra log probs, which means we usually don't know the log prob of the entire partition)
cpath, _, _ = self.run_hmm('forward', self.sub_param_dir, n_procs=n_procs, partition=cpath.bmx(), shuffle_input=True) # note that this annihilates the old <cpath>, which is a memory optimization (but we write all of them to the cpath progress dir)
self.n_proc_list.append(n_procs)
if self.are_we_finished_clustering(n_procs, cpath):
break
n_procs, cpath = self.prepare_next_iteration(cpath, initial_nseqs)
self.istep += 1
if self.args.max_cluster_size is not None:
print(' --max-cluster-size (partitiondriver): merging shared clusters')
self.merge_shared_clusters(cpath)
cpath = self.merge_cpaths_from_previous_steps(cpath)
print(' partition loop time: %.1f' % (time.time()-start))
return cpath
# ----------------------------------------------------------------------------------------
def check_partition(self, partition):
uids = set([uid for cluster in partition for uid in cluster])
input_ids = set(self.sw_info['queries']) # note that this does not include queries that were removed in sw
missing_ids = input_ids - uids
if self.unseeded_seqs is not None:
missing_ids -= set(self.unseeded_seqs)
if self.small_cluster_seqs is not None:
missing_ids -= set(self.small_cluster_seqs)