Skip to content

Commit 20bc3de

Browse files
committed
Optimize two_pointer_sweep function
1 parent 2fca7e8 commit 20bc3de

2 files changed

Lines changed: 189 additions & 134 deletions

File tree

MACS3/Signal/PairedEndTrack.py

Lines changed: 109 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import sys
1717
from array import array as pyarray
1818
from collections import Counter,defaultdict
19+
from operator import itemgetter
1920
# ------------------------------------
2021
# MACS3 modules
2122
# ------------------------------------
@@ -362,6 +363,7 @@ def exclude(self, regions):
362363
:meth:`finalize` to refresh cached statistics.
363364
"""
364365
i: cython.ulong
366+
j: cython.ulong
365367
k: bytes
366368
locs: cnp.ndarray
367369
locs_size: cython.ulong
@@ -1606,15 +1608,19 @@ def exclude(self, regions):
16061608
and finishes by calling :meth:`finalize`.
16071609
"""
16081610
i: cython.ulong
1611+
j: cython.ulong
16091612
k: bytes
16101613
locs: cnp.ndarray
16111614
locs_size: cython.ulong
16121615
chrnames: set
16131616
regions_c: list
16141617
selected_idx: cnp.ndarray
16151618
regions_chrs: list
1616-
r1: cnp.void
1617-
r2: tuple
1619+
r1_start: cython.int
1620+
r1_end: cython.int
1621+
r1_count: cython.ushort
1622+
r2_start: cython.int
1623+
r2_end: cython.int
16181624
n_rl1: cython.long
16191625
n_rl2: cython.long
16201626

@@ -1640,45 +1646,58 @@ def exclude(self, regions):
16401646
selected_idx = np.ones(locs_size, dtype=bool)
16411647

16421648
regions_c = regions.regions[k]
1649+
loc_starts = locs['l']
1650+
loc_ends = locs['r']
1651+
loc_counts = locs['c']
1652+
region_starts = [r[0] for r in regions_c]
1653+
region_ends = [r[1] for r in regions_c]
16431654

16441655
i = 0
1656+
j = 0
16451657
n_rl1 = len(locs)
16461658
n_rl2 = len(regions_c)
1647-
rl1_k = iter(locs).__next__
1648-
rl2_k = iter(regions_c).__next__
1649-
r1 = rl1_k()
1659+
r1_start = loc_starts[i]
1660+
r1_end = loc_ends[i]
1661+
r1_count = loc_counts[i]
16501662
n_rl1 -= 1 # remaining rl1
1651-
r2 = rl2_k()
1663+
r2_start = region_starts[j]
1664+
r2_end = region_ends[j]
16521665
n_rl2 -= 1 # remaining rl2
16531666
while (True):
16541667
# we do this until there is no r1 or r2 left.
1655-
if r2[0] < r1[1] and r1[0] < r2[1]:
1668+
if r2_start < r1_end and r1_start < r2_end:
16561669
# since we found an overlap, r1 will be skipped/excluded
16571670
# and move to the next r1
16581671
# get rid of this one
16591672
n_rl1 -= 1
1660-
self.length -= cython.cast(cython.ulonglong, (r1[1] - r1[0])*r1[2])
1673+
self.length -= cython.cast(cython.ulonglong, (r1_end - r1_start) * r1_count)
16611674
selected_idx[i] = False
16621675

16631676
if n_rl1 >= 0:
1664-
r1 = rl1_k()
16651677
i += 1
1678+
r1_start = loc_starts[i]
1679+
r1_end = loc_ends[i]
1680+
r1_count = loc_counts[i]
16661681
continue
16671682
else:
16681683
break
1669-
if r1[1] < r2[1]:
1684+
if r1_end < r2_end:
16701685
# in this case, we need to move to the next r1,
16711686
n_rl1 -= 1
16721687
if n_rl1 >= 0:
1673-
r1 = rl1_k()
16741688
i += 1
1689+
r1_start = loc_starts[i]
1690+
r1_end = loc_ends[i]
1691+
r1_count = loc_counts[i]
16751692
else:
16761693
# no more r1 left
16771694
break
16781695
else:
16791696
# in this case, we need to move the next r2
16801697
if n_rl2:
1681-
r2 = rl2_k()
1698+
j += 1
1699+
r2_start = region_starts[j]
1700+
r2_end = region_ends[j]
16821701
n_rl2 -= 1
16831702
else:
16841703
# no more r2 left
@@ -1830,120 +1849,143 @@ def _anndata_ncls(self, regions):
18301849
adata = ad.AnnData(X=X, obs=obs, var=var)
18311850
return adata
18321851

1833-
def _two_pointer_sweep(self,regions):
1834-
barcode_items = list(self.barcode_dict.items())
1835-
barcodes = [b.decode() if isinstance(b, (bytes, bytearray)) else str(b) for b, _ in barcode_items]
1836-
barcode_ids = np.array([i for _, i in barcode_items], dtype=np.int64)
1837-
1838-
id_map = np.full(barcode_ids.max() + 1, -1, dtype=np.int64)
1839-
id_map[barcode_ids] = np.arange(len(barcode_ids), dtype=np.int64)
1840-
n_cells = len(barcodes)
1841-
1842-
# # array of locataions and counts
1843-
# fragment_locs = petrack.locations[chrom][:petrack.size[chrom]]
1844-
# # arrray of barcodes
1845-
# fragment_barcodes = petrack.barcodes[chrom][:petrack.size[chrom]]
1852+
@cython.boundscheck(False) # do not check that np indices are valid
1853+
@cython.cfunc
1854+
def _two_pointer_sweep(self, regions):
1855+
frag_idx: cython.int
1856+
local_peak_idx: cython.int
1857+
n_frags: cython.int
1858+
remaining_frag_len: cython.int
1859+
remaining_peak_len: cython.int
1860+
back_trace_frag: cython.int
1861+
row_id: cython.int
1862+
peak_start: cython.int
1863+
peak_end: cython.int
1864+
frag_start: cython.int
1865+
frag_end: cython.int
1866+
barcode_items: list
1867+
barcode_ids: object
1868+
barcode_id_to_row: object
1869+
rows: list
1870+
columns: list
1871+
data: list
18461872

18471873
peak_names = []
18481874
peak_data = []
18491875

1850-
rows = [] #barcodes index
1876+
rows = [] #barcode index
18511877
columns = [] #peaks index
1852-
data = [] # count data int
1878+
data = [] #count data
18531879

1854-
barcode_items = sorted(self.barcode_dict.items(), key=lambda x: x[1])
1880+
barcode_items = sorted(self.barcode_dict.items(), key=itemgetter(1))
18551881
barcodes = [b.decode() if isinstance(b, (bytes, bytearray)) else str(b) for b, _ in barcode_items]
1856-
barcode_id_to_row = {barcode_id: i for i, (_, barcode_id) in enumerate(barcode_items)}
18571882
n_cells = len(barcodes)
1883+
if n_cells:
1884+
barcode_ids = np.fromiter((barcode_id for _, barcode_id in barcode_items),
1885+
dtype=np.int64,
1886+
count=n_cells)
1887+
barcode_id_to_row = np.full(int(barcode_ids[-1]) + 1, -1, dtype=np.int64)
1888+
barcode_id_to_row[barcode_ids] = np.arange(n_cells, dtype=np.int64)
1889+
else:
1890+
barcode_id_to_row = np.zeros(0, dtype=np.int64)
18581891

18591892
regions.sort()
18601893
peak_counter = 0
1894+
18611895
for chrom in regions.regions.keys():
1862-
#barcodes = petrack.barcodes
1896+
if chrom not in self.locations:
1897+
continue
1898+
18631899
regions_c = regions.regions[chrom]
1900+
if not regions_c:
1901+
continue
1902+
18641903
local_peak = []
1904+
peak_starts = []
1905+
peak_ends = []
1906+
chrom_str = chrom.decode() if isinstance(chrom, (bytes, bytearray)) else str(chrom)
18651907
### regions empty skip
18661908
for (start, end) in regions_c:
18671909
peak_counter += 1
18681910
peak_names.append(f"peak_{peak_counter}")
1869-
peak_data.append((chrom.decode(),start,end))
1911+
peak_data.append((chrom_str, start, end))
18701912
local_peak.append(peak_counter - 1)
1913+
peak_starts.append(start)
1914+
peak_ends.append(end)
18711915

18721916
fragment_locs = self.locations[chrom]
1873-
fragment_barcodes = self.barcodes[chrom]
18741917

18751918
if len(fragment_locs) == 0:
18761919
continue
18771920

1921+
fragment_barcodes = self.barcodes[chrom]
1922+
frag_starts = fragment_locs['l']
1923+
frag_ends = fragment_locs['r']
1924+
frag_counts = fragment_locs['c']
1925+
frag_rows = barcode_id_to_row[fragment_barcodes]
1926+
n_frags = len(fragment_locs)
1927+
18781928
frag_idx = 0
18791929
local_peak_idx = 0
1880-
frag_len = len(fragment_locs)
1881-
peak_len = len(regions_c)
1882-
frag = fragment_locs[frag_idx]
1883-
remaining_frag_len = frag_len - 1
1884-
peak = regions_c[local_peak_idx] # (start, end)
1885-
remaining_peak_len = peak_len - 1
1886-
1887-
# inside two_pointer_sweep, replace only the while-loop block with this
1930+
frag_start = frag_starts[frag_idx]
1931+
frag_end = frag_ends[frag_idx]
1932+
peak_start = peak_starts[local_peak_idx]
1933+
peak_end = peak_ends[local_peak_idx]
1934+
remaining_frag_len = n_frags - 1
1935+
remaining_peak_len = len(regions_c) - 1
18881936
back_trace_frag = 0
1889-
while True:
1890-
frag_start, frag_end = frag[0], frag[1]
1891-
peak_start, peak_end = peak[0], peak[1]
18921937

1893-
# peak overlap fragment
1938+
while True:
18941939
if frag_start <= peak_end and peak_start <= frag_end:
1895-
bc_id = int(fragment_barcodes[frag_idx])
1896-
row_id = barcode_id_to_row.get(bc_id, -1)
1940+
row_id = frag_rows[frag_idx]
18971941
if row_id >= 0:
18981942
rows.append(int(row_id))
18991943
columns.append(local_peak[local_peak_idx])
1900-
data.append(int(frag[2]))
1944+
data.append(int(frag_counts[frag_idx]))
19011945

19021946
if frag_end > peak_end:
1903-
# overhang case, perhaps the frag can still contain next peak(s)
19041947
back_trace_frag += 1
1905-
remaining_frag_len -= 1
1906-
if remaining_frag_len >= 0:
1948+
1949+
if remaining_frag_len:
1950+
remaining_frag_len -= 1
19071951
frag_idx += 1
1908-
frag = fragment_locs[frag_idx]
1952+
frag_start = frag_starts[frag_idx]
1953+
frag_end = frag_ends[frag_idx]
19091954
continue
19101955
else:
19111956
break
19121957

1913-
# if fragment ends before peak ends -> advance fragment
19141958
if frag_end < peak_end:
1915-
remaining_frag_len -= 1
1916-
if remaining_frag_len >= 0:
1959+
if remaining_frag_len:
1960+
remaining_frag_len -= 1
19171961
frag_idx += 1
1918-
frag = fragment_locs[frag_idx]
1962+
frag_start = frag_starts[frag_idx]
1963+
frag_end = frag_ends[frag_idx]
19191964
else:
19201965
break
19211966
else:
1922-
# advance peak
1923-
remaining_peak_len -= 1
1924-
if remaining_peak_len >= 0:
1967+
if remaining_peak_len:
1968+
remaining_peak_len -= 1
19251969
local_peak_idx += 1
1926-
peak = regions_c[local_peak_idx]
1927-
# we will also check if backtrace > 0 or not, if so, we need to move back the fragment pointer to check those peaks that were skipped due to overhang
1928-
if back_trace_frag > 0:
1970+
peak_start = peak_starts[local_peak_idx]
1971+
peak_end = peak_ends[local_peak_idx]
1972+
if back_trace_frag:
19291973
frag_idx -= back_trace_frag
19301974
remaining_frag_len += back_trace_frag
1931-
frag = fragment_locs[frag_idx]
1932-
back_trace_frag = 0
1975+
frag_start = frag_starts[frag_idx]
1976+
frag_end = frag_ends[frag_idx]
1977+
back_trace_frag = 0
19331978
else:
19341979
break
19351980

19361981
n_peaks = peak_counter
1937-
x = sparse.coo_matrix((data,(rows,columns)),shape=(n_cells,n_peaks)).tocsr()
19381982
obs = pd.DataFrame(index=barcodes)
1939-
#obs['n_fragments_in_peaks'] = np.asarray(x.sum(axis=1)).ravel()
19401983
var = pd.DataFrame(peak_data, columns=['chrom', 'start', 'end'], index=peak_names)
19411984

1985+
x = sparse.csr_matrix((data,(rows,columns)),shape=(n_cells,n_peaks))
19421986
adata_peaks_loop = ad.AnnData(X=x, obs=obs, var=var)
19431987
return adata_peaks_loop
19441988

1945-
1946-
19471989
def return_anndata(self, regions, method = "ncls"):
19481990
"""
19491991
Build barcode × peak AnnData using the selected method.

0 commit comments

Comments
 (0)