Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/publish-to-gh-with-sphinx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Build and Publish Sphinx Documentation

on:
push:
branches: ["master"]
branches: ["main"]
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:

Expand Down
208 changes: 196 additions & 12 deletions MACS3/Signal/PairedEndTrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import sys
from array import array as pyarray
from collections import Counter,defaultdict
from operator import itemgetter
# ------------------------------------
# MACS3 modules
# ------------------------------------
Expand All @@ -38,6 +39,10 @@

from MACS3.Utilities.Logger import logging

import pandas as pd
from scipy import sparse
import anndata as ad

logger = logging.getLogger(__name__)
debug = logger.debug
info = logger.info
Expand Down Expand Up @@ -350,13 +355,15 @@ def exclude(self, regions):
----------
regions : MACS3.Signal.Region.Regions
Sorted region collection whose intervals should be excluded.
The default path merges overlapping or adjacent intervals on a copy.

Notes
-----
The operation mutates the track in place and finishes by calling
:meth:`finalize` to refresh cached statistics.
"""
i: cython.ulong
j: cython.ulong
k: bytes
locs: cnp.ndarray
locs_size: cython.ulong
Expand Down Expand Up @@ -1601,15 +1608,19 @@ def exclude(self, regions):
and finishes by calling :meth:`finalize`.
"""
i: cython.ulong
j: cython.ulong
k: bytes
locs: cnp.ndarray
locs_size: cython.ulong
chrnames: set
regions_c: list
selected_idx: cnp.ndarray
regions_chrs: list
r1: cnp.void
r2: tuple
r1_start: cython.int
r1_end: cython.int
r1_count: cython.ushort
r2_start: cython.int
r2_end: cython.int
n_rl1: cython.long
n_rl2: cython.long

Expand All @@ -1635,45 +1646,58 @@ def exclude(self, regions):
selected_idx = np.ones(locs_size, dtype=bool)

regions_c = regions.regions[k]
loc_starts = locs['l']
loc_ends = locs['r']
loc_counts = locs['c']
region_starts = [r[0] for r in regions_c]
region_ends = [r[1] for r in regions_c]

i = 0
j = 0
n_rl1 = len(locs)
n_rl2 = len(regions_c)
rl1_k = iter(locs).__next__
rl2_k = iter(regions_c).__next__
r1 = rl1_k()
r1_start = loc_starts[i]
r1_end = loc_ends[i]
r1_count = loc_counts[i]
n_rl1 -= 1 # remaining rl1
r2 = rl2_k()
r2_start = region_starts[j]
r2_end = region_ends[j]
n_rl2 -= 1 # remaining rl2
while (True):
# we do this until there is no r1 or r2 left.
if r2[0] < r1[1] and r1[0] < r2[1]:
if r2_start < r1_end and r1_start < r2_end:
# since we found an overlap, r1 will be skipped/excluded
# and move to the next r1
# get rid of this one
n_rl1 -= 1
self.length -= cython.cast(cython.ulonglong, (r1[1] - r1[0])*r1[2])
self.length -= cython.cast(cython.ulonglong, (r1_end - r1_start) * r1_count)
selected_idx[i] = False

if n_rl1 >= 0:
r1 = rl1_k()
i += 1
r1_start = loc_starts[i]
r1_end = loc_ends[i]
r1_count = loc_counts[i]
continue
else:
break
if r1[1] < r2[1]:
if r1_end < r2_end:
# in this case, we need to move to the next r1,
n_rl1 -= 1
if n_rl1 >= 0:
r1 = rl1_k()
i += 1
r1_start = loc_starts[i]
r1_end = loc_ends[i]
r1_count = loc_counts[i]
else:
# no more r1 left
break
else:
# in this case, we need to move the next r2
if n_rl2:
r2 = rl2_k()
j += 1
r2_start = region_starts[j]
r2_end = region_ends[j]
n_rl2 -= 1
else:
# no more r2 left
Expand All @@ -1690,7 +1714,167 @@ def exclude(self, regions):
selected_idx.resize(0, refcheck=False)
self.finalize()
return

@cython.boundscheck(False)  # do not check that np indices are valid
@cython.cfunc
def _two_pointer_sweep(self, regions):
    """
    Count fragments per barcode and peak and return them as AnnData.

    Despite its name, the overlap search is done with two vectorised
    binary searches (``np.searchsorted``) per chromosome rather than
    with an explicit two-pointer loop.

    Parameters
    ----------
    regions : MACS3.Signal.Region.Regions
        Peak intervals keyed by chromosome in ``regions.regions``.
        ``regions.sort()`` is called on this object. Within each
        chromosome the intervals must not overlap each other, so that
        both the start and the end coordinates are ascending once
        sorted; ``np.searchsorted`` gives wrong answers otherwise.

    Returns
    -------
    anndata.AnnData
        ``X`` is an int32 CSR matrix of shape (barcodes, peaks).
        Rows follow the barcodes of ``self.barcode_dict`` ordered by
        their integer id; ``obs`` is indexed by the barcode strings.
        Columns are named ``peak_1``, ``peak_2``, ... and ``var``
        carries ``chrom``, ``start`` and ``end``. Each entry is the
        sum of the ``'c'`` counts of the fragments of that barcode
        which overlap the peak.

    Notes
    -----
    Regions on a chromosome that is absent from ``self.locations``
    get no column at all. Regions on a chromosome that is present but
    holds no fragment still get (all-zero) columns.

    Fragments whose barcode id is not covered by the id-to-row lookup
    table (negative, or larger than the largest id in
    ``self.barcode_dict``, which includes every id when the dictionary
    is empty) are ignored instead of raising ``IndexError``.
    """
    peak_idx: cython.int
    n_regions_c: cython.int
    n_cells: cython.int
    peak_counter: cython.int
    peak_base: cython.int
    start: cython.int
    end: cython.int
    chrom: bytes
    chrom_str: str
    barcode_items: list
    regions_c: list
    barcode_ids: cnp.ndarray(cnp.int32_t, ndim=1)
    barcode_id_to_row: cnp.ndarray(cnp.int32_t, ndim=1)
    fragment_locs: cnp.ndarray
    fragment_barcodes: cnp.ndarray(cnp.int32_t, ndim=1)
    peak_starts: cnp.ndarray(cnp.int32_t, ndim=1)
    peak_ends: cnp.ndarray(cnp.int32_t, ndim=1)
    frag_starts: cnp.ndarray(cnp.int32_t, ndim=1)
    frag_ends: cnp.ndarray(cnp.int32_t, ndim=1)
    frag_counts: cnp.ndarray(cnp.uint16_t, ndim=1)
    frag_rows: cnp.ndarray(cnp.int32_t, ndim=1)
    left_idx: cnp.ndarray(cnp.int32_t, ndim=1)
    right_idx: cnp.ndarray(cnp.int32_t, ndim=1)
    widths: cnp.ndarray(cnp.int32_t, ndim=1)
    known_mask: cnp.ndarray
    valid_mask: cnp.ndarray
    chunk_rows: cnp.ndarray(cnp.int32_t, ndim=1)
    chunk_cols: cnp.ndarray(cnp.int32_t, ndim=1)
    chunk_data: cnp.ndarray(cnp.int32_t, ndim=1)
    valid_rows: cnp.ndarray(cnp.int32_t, ndim=1)
    valid_counts: cnp.ndarray(cnp.int32_t, ndim=1)
    valid_left: cnp.ndarray(cnp.int32_t, ndim=1)
    valid_widths: cnp.ndarray(cnp.int32_t, ndim=1)
    chunk_offsets: cnp.ndarray(cnp.int32_t, ndim=1)
    repeated_offsets: cnp.ndarray(cnp.int32_t, ndim=1)
    repeated_left: cnp.ndarray(cnp.int32_t, ndim=1)
    intra_offsets: cnp.ndarray(cnp.int32_t, ndim=1)
    rows_arr: cnp.ndarray(cnp.int32_t, ndim=1)
    cols_arr: cnp.ndarray(cnp.int32_t, ndim=1)
    data_arr: cnp.ndarray(cnp.int32_t, ndim=1)
    chunk_nnz: cython.int
    row_chunks: list
    col_chunks: list
    data_chunks: list
    peak_names: list
    peak_data: list
    peak_names_append: object
    peak_data_append: object

    peak_names = []
    peak_data = []
    row_chunks = []
    col_chunks = []
    data_chunks = []
    peak_names_append = peak_names.append
    peak_data_append = peak_data.append

    # Matrix rows: barcodes ordered by their integer id.
    barcode_items = sorted(self.barcode_dict.items(), key=itemgetter(1))
    barcodes = [b.decode() if isinstance(b, (bytes, bytearray)) else str(b) for b, _ in barcode_items]
    n_cells = len(barcodes)
    if n_cells:
        barcode_ids = np.fromiter((barcode_id for _, barcode_id in barcode_items),
                                  dtype=np.int32,
                                  count=n_cells)
        # Lookup table id -> row; ids are sorted, so the last one is the
        # largest. Slots of ids that are not in the dictionary stay -1.
        barcode_id_to_row = np.full(int(barcode_ids[-1]) + 1, -1, dtype=np.int32)
        barcode_id_to_row[barcode_ids] = np.arange(n_cells, dtype=np.int32)
    else:
        barcode_id_to_row = np.zeros(0, dtype=np.int32)

    regions.sort()
    peak_counter = 0

    for chrom in sorted(regions.regions.keys()):
        if chrom not in self.locations:
            continue

        regions_c = regions.regions[chrom]
        if not regions_c:
            continue

        # Register the peaks of this chromosome as matrix columns
        # peak_base .. peak_base + n_regions_c - 1.
        n_regions_c = len(regions_c)
        peak_starts = np.empty(n_regions_c, dtype=np.int32)
        peak_ends = np.empty(n_regions_c, dtype=np.int32)
        peak_base = peak_counter
        chrom_str = chrom.decode() if isinstance(chrom, (bytes, bytearray)) else str(chrom)
        for peak_idx, (start, end) in enumerate(regions_c):
            peak_counter += 1
            peak_names_append(f"peak_{peak_counter}")
            peak_data_append((chrom_str, start, end))
            peak_starts[peak_idx] = start
            peak_ends[peak_idx] = end

        fragment_locs = self.locations[chrom]
        if len(fragment_locs) == 0:
            continue

        fragment_barcodes = self.barcodes[chrom]
        frag_starts = fragment_locs['l']
        frag_ends = fragment_locs['r']
        frag_counts = fragment_locs['c']

        # Translate barcode ids to matrix rows. Only ids that fall inside
        # the lookup table are looked up; everything else keeps the -1
        # sentinel and is dropped by valid_mask below. Without this guard
        # an id past the end of the table raises IndexError and a negative
        # id silently wraps around to an unrelated row.
        known_mask = np.logical_and(fragment_barcodes >= 0,
                                    fragment_barcodes < barcode_id_to_row.shape[0])
        frag_rows = np.full(fragment_barcodes.shape[0], -1, dtype=np.int32)
        frag_rows[known_mask] = barcode_id_to_row[fragment_barcodes[known_mask]]

        # A fragment overlaps peak p iff peak_end[p] > frag_start and
        # peak_start[p] < frag_end. left_idx is the first peak whose end
        # lies beyond the fragment start, right_idx is one past the last
        # peak whose start lies before the fragment end, so the fragment
        # overlaps exactly the peaks left_idx .. right_idx - 1.
        left_idx = np.searchsorted(peak_ends, frag_starts, side='right').astype(np.int32, copy=False)
        right_idx = np.searchsorted(peak_starts, frag_ends, side='left').astype(np.int32, copy=False)
        widths = right_idx - left_idx
        valid_mask = np.logical_and(frag_rows >= 0, widths > 0)
        chunk_nnz = int(widths[valid_mask].sum())
        if chunk_nnz:
            valid_rows = frag_rows[valid_mask]
            valid_counts = frag_counts[valid_mask].astype(np.int32, copy=False)
            valid_left = left_idx[valid_mask]
            valid_widths = widths[valid_mask]
            # One (row, col, count) triple per fragment/peak overlap:
            # every fragment is repeated once for each peak it overlaps.
            chunk_rows = np.repeat(valid_rows, valid_widths)
            chunk_data = np.repeat(valid_counts, valid_widths)
            # chunk_offsets[i] is where fragment i's run of triples starts.
            chunk_offsets = np.empty(valid_widths.shape[0] + 1, dtype=np.int32)
            chunk_offsets[0] = 0
            np.cumsum(valid_widths, out=chunk_offsets[1:])
            repeated_offsets = np.repeat(chunk_offsets[:-1], valid_widths)
            repeated_left = np.repeat(valid_left, valid_widths)
            # Position inside each run: 0, 1, ..., width - 1.
            intra_offsets = np.arange(chunk_nnz, dtype=np.int32) - repeated_offsets
            chunk_cols = peak_base + repeated_left + intra_offsets
            row_chunks.append(chunk_rows)
            col_chunks.append(chunk_cols)
            data_chunks.append(chunk_data)

    n_peaks = peak_counter
    obs = pd.DataFrame(index=barcodes)
    var = pd.DataFrame(peak_data, columns=['chrom', 'start', 'end'], index=peak_names)
    if row_chunks:
        rows_arr = np.concatenate(row_chunks)
        cols_arr = np.concatenate(col_chunks)
        data_arr = np.concatenate(data_chunks)
        # Triples that share (row, col) are added up when the CSR matrix
        # is built, which yields the per barcode/peak totals.
        x = sparse.csr_matrix((data_arr, (rows_arr, cols_arr)), shape=(n_cells, n_peaks), dtype=np.int32)
    else:
        x = sparse.csr_matrix((n_cells, n_peaks), dtype=np.int32)
    adata_peaks_loop = ad.AnnData(X=x, obs=obs, var=var)
    return adata_peaks_loop

def return_anndata(self, regions):
    """
    Build a barcode × peak AnnData object for the given regions.

    Parameters
    ----------
    regions : MACS3.Signal.Region.Regions
        Region collection whose intervals are used as peaks. Each
        per-chromosome interval list is shallow-copied into a new
        ``Regions`` object, ``merge_overlap()`` is called on that
        copy, and the copy is what gets counted against.

    Returns
    -------
    The object returned by ``_two_pointer_sweep`` for the merged copy.
    """
    regions_copy: Regions
    chrom: bytes

    regions_copy = Regions()
    for chrom in sorted(regions.regions):
        # Copy the list so the work below never touches the caller's list.
        intervals = list(regions.regions[chrom])
        regions_copy.regions[chrom] = intervals
        regions_copy.total += len(intervals)

    regions_copy.merge_overlap()
    return self._two_pointer_sweep(regions_copy)

@cython.ccall
def sample_percent(self,
percent: cython.float,
Expand Down
Loading
Loading