1616import sys
1717from array import array as pyarray
1818from collections import Counter ,defaultdict
19+ from operator import itemgetter
1920# ------------------------------------
2021# MACS3 modules
2122# ------------------------------------
@@ -362,6 +363,7 @@ def exclude(self, regions):
362363 :meth:`finalize` to refresh cached statistics.
363364 """
364365 i : cython .ulong
366+ j : cython .ulong
365367 k : bytes
366368 locs : cnp .ndarray
367369 locs_size : cython .ulong
@@ -1606,15 +1608,19 @@ def exclude(self, regions):
16061608 and finishes by calling :meth:`finalize`.
16071609 """
16081610 i : cython .ulong
1611+ j : cython .ulong
16091612 k : bytes
16101613 locs : cnp .ndarray
16111614 locs_size : cython .ulong
16121615 chrnames : set
16131616 regions_c : list
16141617 selected_idx : cnp .ndarray
16151618 regions_chrs : list
1616- r1 : cnp .void
1617- r2 : tuple
1619+ r1_start : cython .int
1620+ r1_end : cython .int
1621+ r1_count : cython .ushort
1622+ r2_start : cython .int
1623+ r2_end : cython .int
16181624 n_rl1 : cython .long
16191625 n_rl2 : cython .long
16201626
@@ -1640,45 +1646,58 @@ def exclude(self, regions):
16401646 selected_idx = np .ones (locs_size , dtype = bool )
16411647
16421648 regions_c = regions .regions [k ]
1649+ loc_starts = locs ['l' ]
1650+ loc_ends = locs ['r' ]
1651+ loc_counts = locs ['c' ]
1652+ region_starts = [r [0 ] for r in regions_c ]
1653+ region_ends = [r [1 ] for r in regions_c ]
16431654
16441655 i = 0
1656+ j = 0
16451657 n_rl1 = len (locs )
16461658 n_rl2 = len (regions_c )
1647- rl1_k = iter ( locs ). __next__
1648- rl2_k = iter ( regions_c ). __next__
1649- r1 = rl1_k ()
1659+ r1_start = loc_starts [ i ]
1660+ r1_end = loc_ends [ i ]
1661+ r1_count = loc_counts [ i ]
16501662 n_rl1 -= 1 # remaining rl1
1651- r2 = rl2_k ()
1663+ r2_start = region_starts [j ]
1664+ r2_end = region_ends [j ]
16521665 n_rl2 -= 1 # remaining rl2
16531666 while (True ):
16541667 # we do this until there is no r1 or r2 left.
1655- if r2 [ 0 ] < r1 [ 1 ] and r1 [ 0 ] < r2 [ 1 ] :
1668+ if r2_start < r1_end and r1_start < r2_end :
16561669 # since we found an overlap, r1 will be skipped/excluded
16571670 # and move to the next r1
16581671 # get rid of this one
16591672 n_rl1 -= 1
1660- self .length -= cython .cast (cython .ulonglong , (r1 [ 1 ] - r1 [ 0 ]) * r1 [ 2 ] )
1673+ self .length -= cython .cast (cython .ulonglong , (r1_end - r1_start ) * r1_count )
16611674 selected_idx [i ] = False
16621675
16631676 if n_rl1 >= 0 :
1664- r1 = rl1_k ()
16651677 i += 1
1678+ r1_start = loc_starts [i ]
1679+ r1_end = loc_ends [i ]
1680+ r1_count = loc_counts [i ]
16661681 continue
16671682 else :
16681683 break
1669- if r1 [ 1 ] < r2 [ 1 ] :
1684+ if r1_end < r2_end :
16701685 # in this case, we need to move to the next r1,
16711686 n_rl1 -= 1
16721687 if n_rl1 >= 0 :
1673- r1 = rl1_k ()
16741688 i += 1
1689+ r1_start = loc_starts [i ]
1690+ r1_end = loc_ends [i ]
1691+ r1_count = loc_counts [i ]
16751692 else :
16761693 # no more r1 left
16771694 break
16781695 else :
16791696 # in this case, we need to move the next r2
16801697 if n_rl2 :
1681- r2 = rl2_k ()
1698+ j += 1
1699+ r2_start = region_starts [j ]
1700+ r2_end = region_ends [j ]
16821701 n_rl2 -= 1
16831702 else :
16841703 # no more r2 left
@@ -1830,120 +1849,143 @@ def _anndata_ncls(self, regions):
18301849 adata = ad .AnnData (X = X , obs = obs , var = var )
18311850 return adata
18321851
@cython.boundscheck(False)  # do not check that np indices are valid
@cython.cfunc
def _two_pointer_sweep(self, regions):
    """Build a barcode x peak AnnData object with a two-pointer sweep.

    For every chromosome present in both ``regions.regions`` and
    ``self.locations``, fragments and peaks are walked in parallel.  Each
    fragment/peak pair satisfying
    ``frag_start <= peak_end and peak_start <= frag_end`` contributes the
    fragment's count (field ``'c'``) to the matrix cell addressed by the
    fragment's barcode row and the peak's column.

    Parameters
    ----------
    regions
        Object exposing ``sort()`` and a ``regions`` mapping from
        chromosome name to a list of ``(start, end)`` pairs.  ``sort()``
        is called on it, so the argument is modified in place.

    Returns
    -------
    AnnData
        ``X`` is a CSR matrix of shape ``(n_barcodes, n_peaks)``; ``obs``
        is indexed by the decoded barcodes ordered by their integer id in
        ``self.barcode_dict``; ``var`` holds ``chrom``/``start``/``end``
        columns and is indexed ``peak_1``, ``peak_2``, ...  Peaks on
        chromosomes absent from ``self.locations`` are not listed.

    Notes
    -----
    NOTE(review): the sweep assumes the fragments of a chromosome are
    ordered by start coordinate and that the peaks returned after
    ``regions.sort()`` are ordered and do not overlap each other --
    confirm against the callers / ``finalize``.

    A fragment that reaches the end of the current peak may also overlap
    later peaks.  The index of the first such fragment is remembered and
    the fragment pointer is rewound to it when the peak pointer advances,
    including when the fragment list has already been exhausted.
    """
    frag_idx: cython.int
    local_peak_idx: cython.int
    n_frags: cython.int
    n_local_peaks: cython.int
    rewind_frag_idx: cython.int
    row_id: cython.int
    peak_start: cython.int
    peak_end: cython.int
    frag_start: cython.int
    frag_end: cython.int
    barcode_items: list
    barcode_ids: object
    barcode_id_to_row: object
    rows: list
    columns: list
    data: list

    peak_names = []
    peak_data = []

    rows = []     # barcode (row) index of every recorded overlap
    columns = []  # peak (column) index of every recorded overlap
    data = []     # fragment count of every recorded overlap

    # Order barcodes by their integer id so that row order is deterministic.
    barcode_items = sorted(self.barcode_dict.items(), key=itemgetter(1))
    barcodes = [b.decode() if isinstance(b, (bytes, bytearray)) else str(b)
                for b, _ in barcode_items]
    n_cells = len(barcodes)
    if n_cells:
        barcode_ids = np.fromiter((barcode_id for _, barcode_id in barcode_items),
                                  dtype=np.int64,
                                  count=n_cells)
        # Dense lookup table: barcode id -> row number, -1 for unused ids.
        # barcode_ids is sorted, so its last element is the largest id.
        barcode_id_to_row = np.full(int(barcode_ids[-1]) + 1, -1, dtype=np.int64)
        barcode_id_to_row[barcode_ids] = np.arange(n_cells, dtype=np.int64)
    else:
        barcode_id_to_row = np.zeros(0, dtype=np.int64)

    regions.sort()
    peak_counter = 0

    for chrom in regions.regions.keys():
        if chrom not in self.locations:
            continue

        regions_c = regions.regions[chrom]
        if not regions_c:
            continue

        local_peak = []   # global column index of each peak on this chromosome
        peak_starts = []
        peak_ends = []
        chrom_str = chrom.decode() if isinstance(chrom, (bytes, bytearray)) else str(chrom)
        for (start, end) in regions_c:
            peak_counter += 1
            peak_names.append(f"peak_{peak_counter}")
            peak_data.append((chrom_str, start, end))
            local_peak.append(peak_counter - 1)
            peak_starts.append(start)
            peak_ends.append(end)

        fragment_locs = self.locations[chrom]

        if len(fragment_locs) == 0:
            continue

        fragment_barcodes = self.barcodes[chrom]
        frag_starts = fragment_locs['l']
        frag_ends = fragment_locs['r']
        frag_counts = fragment_locs['c']
        # Translate every fragment's barcode id to its matrix row up front.
        frag_rows = barcode_id_to_row[fragment_barcodes]
        n_frags = len(fragment_locs)
        n_local_peaks = len(regions_c)

        frag_idx = 0
        local_peak_idx = 0
        # Index of the first fragment that reaches the end of the current
        # peak (-1 when there is none); the sweep restarts from it for the
        # next peak so that no fragment/peak pair is missed.
        rewind_frag_idx = -1
        peak_start = peak_starts[local_peak_idx]
        peak_end = peak_ends[local_peak_idx]

        while True:
            if frag_idx < n_frags:
                frag_start = frag_starts[frag_idx]
                frag_end = frag_ends[frag_idx]

                if frag_start <= peak_end and peak_start <= frag_end:
                    # fragment overlaps the current peak
                    row_id = frag_rows[frag_idx]
                    if row_id >= 0:
                        rows.append(int(row_id))
                        columns.append(local_peak[local_peak_idx])
                        data.append(int(frag_counts[frag_idx]))

                    # `>=` mirrors the inclusive overlap test above: a
                    # fragment ending exactly at peak_end would still
                    # match a peak starting at that coordinate.
                    if frag_end >= peak_end and rewind_frag_idx < 0:
                        rewind_frag_idx = frag_idx

                    frag_idx += 1
                    continue

                if frag_end < peak_end:
                    # fragment lies entirely before the peak
                    frag_idx += 1
                    continue
                # otherwise the fragment lies after the peak: advance peak
            elif rewind_frag_idx < 0:
                # fragments exhausted and nothing pending for later peaks
                break

            local_peak_idx += 1
            if local_peak_idx >= n_local_peaks:
                break
            peak_start = peak_starts[local_peak_idx]
            peak_end = peak_ends[local_peak_idx]
            if rewind_frag_idx >= 0:
                frag_idx = rewind_frag_idx
                rewind_frag_idx = -1

    n_peaks = peak_counter
    obs = pd.DataFrame(index=barcodes)
    var = pd.DataFrame(peak_data, columns=['chrom', 'start', 'end'], index=peak_names)

    # Duplicate (row, column) entries are summed when the CSR matrix is built.
    x = sparse.csr_matrix((data, (rows, columns)), shape=(n_cells, n_peaks))
    adata_peaks_loop = ad.AnnData(X=x, obs=obs, var=var)
    return adata_peaks_loop
19441988
1945-
1946-
19471989 def return_anndata (self , regions , method = "ncls" ):
19481990 """
19491991 Build barcode × peak AnnData using the selected method.
0 commit comments