Commit 1ed29bb

Merge pull request #120 from databio/dev
Release v0.2.7
2 parents a82fc1a + d3f6d0c commit 1ed29bb

17 files changed, +196 -117 lines changed

bindings/python/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [package]
 name = "gtars-py"
-version = "0.2.6"
+version = "0.2.7"
 edition = "2021"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

bindings/python/py_src/gtars/tokenizers/__init__.pyi

Lines changed: 6 additions & 13 deletions
@@ -171,19 +171,12 @@ class Universe:
            str: A string describing the Universe.
        """
 
-def create_instances(
-    sequences: Union[List[int], List[List[int]]],
-    window_size: int,
-    algorithm: str,
-) -> List[Dict[str, Union[int, List[int]]]]:
+def tokenize_fragment_file(file: str, tokenizer: Tokenizer) -> Dict[str, List[int]]:
     """
-    Creates training instances for a given sequence or list of sequences.
-
+    Tokenizes a fragment file using the specified tokenizer.
     Args:
-        sequences (Union[List[int], List[List[int]]]): A sequence or list of sequences of token IDs.
-        window_size (int): The size of the context window.
-        algorithm (str): The algorithm to use ('cbow' or 'sg').
-
+        file (str): The path to the fragment file.
+        tokenizer (Tokenizer): The tokenizer to use for tokenization.
     Returns:
-        List[Dict[str, Union[int, List[int]]]]: A list of dictionaries representing the training instances.
-    """
+        Dict[str, List[int]]: A dictionary mapping cell barcodes to lists of token IDs.
+    """

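Example (illustration only, not part of the diff): a minimal sketch of calling the new binding from Python, assuming the rebuilt wheel from this release; the tokenizer config and fragment-file paths below are hypothetical placeholders.

from gtars.tokenizers import Tokenizer, tokenize_fragment_file

# Hypothetical inputs: a BED-based tokenizer config and a 10x-style fragment file.
tokenizer = Tokenizer("peaks.bed")
barcode_tokens = tokenize_fragment_file("fragments.tsv.gz", tokenizer)

# Per the stub above, the result maps each cell barcode to its list of token IDs.
for barcode, token_ids in barcode_tokens.items():
    print(barcode, len(token_ids))
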
bindings/python/src/models/region_set.rs

Lines changed: 3 additions & 2 deletions
@@ -164,7 +164,8 @@ impl PyRegionSet {
         Ok(())
     }
 
-    fn mean_region_width(&self) -> PyResult<u32> {
-        Ok(self.regionset.mean_region_width())
+    fn mean_region_width(&self) -> f64 {
+        let mean_width = self.regionset.mean_region_width();
+        mean_width
     }
 }

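Example (illustration only, not part of the diff): with this change the binding returns a plain Python float instead of a result-wrapped integer. A minimal sketch, reusing the dummy.narrowPeak test file referenced later in this PR:

from gtars.models import RegionSet

rs = RegionSet("dummy.narrowPeak")  # path to the test file; any BED-like file works

# Now returns a float rounded to two decimals (4.22 for the dummy file)
# instead of a truncated integer.
print(rs.mean_region_width())
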
bindings/python/src/tokenizers/mod.rs

Lines changed: 2 additions & 2 deletions
@@ -6,13 +6,13 @@ mod utils;
 use pyo3::prelude::*;
 
 use crate::tokenizers::py_tokenizers::PyTokenizer;
-use crate::tokenizers::utils::py_create_instances;
+use crate::tokenizers::utils::py_tokenize_fragment_file;
 // use crate::tokenizers::universe::PyUniverse;
 // use crate::tokenizers::encoding::{PyBatchEncoding, PyEncoding};
 
 #[pymodule]
 pub fn tokenizers(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<PyTokenizer>()?;
-    m.add_wrapped(wrap_pyfunction!(py_create_instances))?;
+    m.add_wrapped(wrap_pyfunction!(py_tokenize_fragment_file))?;
     Ok(())
 }

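Example (illustration only, not part of the diff): a quick check that the registration swap is visible from Python, assuming the extension module is rebuilt from this commit.

from gtars import tokenizers

# The new function should be registered in the module and the old one gone.
assert hasattr(tokenizers, "tokenize_fragment_file")
assert not hasattr(tokenizers, "create_instances")
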
bindings/python/src/tokenizers/py_tokenizers/mod.rs

Lines changed: 6 additions & 0 deletions
@@ -292,3 +292,9 @@ impl PyTokenizer {
         })
     }
 }
+
+impl PyTokenizer {
+    pub fn inner(&self) -> &Tokenizer {
+        &self.tokenizer
+    }
+}

Lines changed: 14 additions & 81 deletions
@@ -1,85 +1,18 @@
+use pyo3::exceptions::PyRuntimeError;
 use pyo3::prelude::*;
-use pyo3::types::PyDict;
+use pyo3::types::{IntoPyDict, PyDict};
 
-use rayon::prelude::*;
+use super::PyTokenizer;
+use gtars::tokenizers::utils::fragments::tokenize_fragment_file;
 
-use gtars::tokenizers::utils::r2v::{create_instances, Algorithm, Instance};
-
-#[pyfunction(name = "create_instances")]
-pub fn py_create_instances(
-    sequences: &Bound<'_, PyAny>,
-    window_size: usize,
-    algorithm: &str,
-) -> PyResult<Vec<Py<PyDict>>> {
-    Python::with_gil(|py| {
-        let algorithm = match algorithm {
-            "cbow" => Algorithm::Cbow,
-            "sg" => Algorithm::Sg,
-            _ => return Err(pyo3::exceptions::PyValueError::new_err("Invalid algorithm")),
-        };
-
-        if let Ok(sequence) = sequences.extract::<Vec<u32>>() {
-            let result = create_instances(&sequence, window_size, algorithm);
-            let mapped_dicts = result
-                .into_iter()
-                .map(|instance| {
-                    let dict = PyDict::new_bound(py);
-                    match instance {
-                        Instance::Cbow {
-                            context_ids,
-                            target_id,
-                        } => {
-                            dict.set_item("context_ids", context_ids).unwrap();
-                            dict.set_item("target_id", target_id).unwrap();
-                        }
-                        Instance::Sg {
-                            center_id,
-                            context_ids,
-                        } => {
-                            dict.set_item("center_id", center_id).unwrap();
-                            dict.set_item("context_ids", context_ids).unwrap();
-                        }
-                    }
-                    dict.into()
-                })
-                .collect::<Vec<Py<PyDict>>>();
-            Ok(mapped_dicts)
-        } else if let Ok(sequences) = sequences.extract::<Vec<Vec<u32>>>() {
-            let result: Vec<Vec<Instance>> = sequences
-                .par_iter()
-                .map(|sequence| create_instances(sequence, window_size, algorithm))
-                .collect();
-
-            let mapped_dicts = result
-                .into_iter()
-                .flat_map(|instances| {
-                    instances.into_iter().map(|instance| {
-                        let dict = PyDict::new_bound(py);
-                        match instance {
-                            Instance::Cbow {
-                                context_ids,
-                                target_id,
-                            } => {
-                                dict.set_item("context_ids", context_ids).unwrap();
-                                dict.set_item("target_id", target_id).unwrap();
-                            }
-                            Instance::Sg {
-                                center_id,
-                                context_ids,
-                            } => {
-                                dict.set_item("center_id", center_id).unwrap();
-                                dict.set_item("context_ids", context_ids).unwrap();
-                            }
-                        }
-                        dict.into()
-                    })
-                })
-                .collect::<Vec<Py<PyDict>>>();
-            return Ok(mapped_dicts);
-        } else {
-            return Err(pyo3::exceptions::PyValueError::new_err(
-                "Invalid input type. Must be a sequence or list of sequences.",
-            ));
-        }
-    })
+#[pyfunction(name = "tokenize_fragment_file")]
+pub fn py_tokenize_fragment_file(file: String, tokenizer: &PyTokenizer) -> PyResult<Py<PyDict>> {
+    let res = tokenize_fragment_file(&file, tokenizer.inner());
+    match res {
+        Ok(res) => Python::with_gil(|py| {
+            let py_dict = res.into_py_dict_bound(py);
+            Ok(py_dict.into())
+        }),
+        Err(res) => Err(PyRuntimeError::new_err(res.to_string())),
+    }
 }

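Example (illustration only, not part of the diff): the Err branch above maps any failure from the core tokenize_fragment_file into a Python RuntimeError, so a bad input surfaces as a normal exception. A sketch with hypothetical paths:

from gtars.tokenizers import Tokenizer, tokenize_fragment_file

tokenizer = Tokenizer("peaks.bed")  # hypothetical tokenizer config that exists on disk

try:
    tokenize_fragment_file("does_not_exist.tsv.gz", tokenizer)
except RuntimeError as err:
    # The message comes from the Rust side via PyRuntimeError::new_err.
    print(err)
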
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+import os
+from pathlib import Path
+
+import pytest
+
+from gtars.models import RegionSet
+
+class TestRegionSet:
+
+    @pytest.mark.parametrize(
+        "bed_file",
+        [
+            "https://raw.githubusercontent.com/databio/gtars/refs/heads/master/gtars/tests/data/regionset/dummy.narrowPeak",
+        ],
+    )
+    def test_mean_region_width(self, bed_file):
+
+        rs = RegionSet(bed_file)
+
+        assert rs.mean_region_width() == 4.22

bindings/python/tests/test_tokenizers.py

Lines changed: 1 addition & 0 deletions
@@ -213,6 +213,7 @@ def test_decode_tokens():
     assert decoded == ["chr9:3526071-3526165"]
 
 
+@pytest.mark.skip(reason="Needs to be fixed")
 def test_special_tokens_mask():
     cfg_path = os.path.join(TEST_DATA_DIR, "tokenizers", "peaks.scored.bed")
     tokenizer = Tokenizer(cfg_path)

bindings/r/DESCRIPTION

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 Package: gtars
 Title: Performance critical genomic interval analysis using Rust, in R
-Version: 0.2.5
+Version: 0.2.7
 Authors@R:
     person("Nathan", "LeRoy", , "[email protected]", role = c("aut", "cre"),
            comment = c(ORCID = "0000-0002-7354-7213"))

bindings/r/src/rust/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [package]
 name = 'gtars-r'
-version = '0.2.6'
+version = '0.2.7'
 edition = '2021'
 
 [lib]

bindings/r/src/rust/src/tokenizers.rs

Whitespace-only changes.

gtars/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [package]
 name = "gtars"
-version = "0.2.6"
+version = "0.2.7"
 edition = "2021"
 description = "Performance-critical tools to manipulate, analyze, and process genomic interval data. Primarily focused on building tools for geniml - our genomic machine learning python package."
 license = "MIT"

gtars/docs/changelog.md

Lines changed: 5 additions & 0 deletions
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.2.7]
+- added utility function to tokenize fragment files
+- fixed [#119](https://github.com/databio/gtars/issues/119)
+- fixed [#121](https://github.com/databio/gtars/issues/121)
+
 ## [0.2.6]
 - Fixed Iterator bug in RegionSet Python bindings [#116](https://github.com/databio/gtars/issues/116)
 - Added caching of identifier in RegionSet in Python bindings

gtars/src/common/models/region_set.rs

Lines changed: 60 additions & 15 deletions
@@ -56,28 +56,62 @@ impl TryFrom<&Path> for RegionSet {
 
         let mut header: String = String::new();
 
+        let mut first_line: bool = true;
+
         for line in reader.lines() {
             let string_line = line?;
 
             let parts: Vec<String> = string_line.split('\t').map(|s| s.to_string()).collect();
 
-            if parts.len() < 3 {
-                if string_line.starts_with("browser")
-                    | string_line.starts_with("track")
-                    | string_line.starts_with("#")
-                {
-                    header.push_str(&string_line);
-                }
+            if string_line.starts_with("browser")
+                | string_line.starts_with("track")
+                | string_line.starts_with("#")
+            {
+                header.push_str(&string_line);
+                first_line = false;
                 continue;
             }
 
+            // Handling column headers like `chr start end etc` without #
+            if first_line {
+                if parts.len() >= 3 {
+                    let is_header: bool = match parts[1].parse::<u32>() {
+                        Ok(_num) => false,
+                        Err(_) => true,
+                    };
+                    if is_header {
+                        header.push_str(&string_line);
+                        first_line = false;
+                        continue;
+                    }
+                }
+                first_line = false;
+            }
+
             new_regions.push(Region {
                 chr: parts[0].to_owned(),
 
                 // To ensure that lines are regions, and we can parse it, we are using Result matching
-                // And it helps to skip lines that are headers.
-                start: parts[1].parse()?,
-                end: parts[2].parse()?,
+                start: match parts[1].parse() {
+                    Ok(start) => start,
+                    Err(_err) => {
+                        return Err(Error::new(
+                            ErrorKind::Other,
+                            format!("Error in parsing start position: {:?}", parts),
+                        )
+                        .into())
+                    }
+                },
+                end: match parts[2].parse() {
+                    Ok(end) => end,
+                    Err(_err) => {
+                        return Err(Error::new(
+                            ErrorKind::Other,
+                            format!("Error in parsing end position: {:?}", parts),
+                        )
+                        .into())
+                    }
+                },
                 rest: Some(parts[3..].join("\t")).filter(|s| !s.is_empty()),
             });
         }
@@ -391,18 +425,16 @@ impl RegionSet {
         false
     }
 
-    pub fn mean_region_width(&self) -> u32 {
-        if self.is_empty() {
-            return 0;
-        }
+    pub fn mean_region_width(&self) -> f64 {
         let sum: u32 = self
             .regions
             .iter()
             .map(|region| region.end - region.start)
             .sum();
         let count: u32 = self.regions.len() as u32;
 
-        sum / count
+        // must be f64 because python doesn't understand f32
+        ((sum as f64 / count as f64) * 100.0).round() / 100.0
     }
 
     ///
@@ -542,4 +574,17 @@ mod tests {
         assert_eq!(region_set.file_digest(), "6224c4d40832b3e0889250f061e01120");
         assert_eq!(region_set.identifier(), "f0b2cf73383b53bd97ff525a0380f200")
     }
+
+    #[test]
+    fn test_mean_region_width() {
+        let file_path = get_test_path("dummy.narrowPeak").unwrap();
+        let region_set = RegionSet::try_from(file_path.to_str().unwrap()).unwrap();
+
+        assert_eq!(region_set.mean_region_width(), 4.22)
+    }
+    #[test]
+    fn test_open_file_with_incorrect_headers() {
+        let file_path = get_test_path("dummy_incorrect_headers.bed").unwrap();
+        let region_set = RegionSet::try_from(file_path.to_str().unwrap()).unwrap();
+    }
 }

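Example (illustration only, not part of the diff): the reader now treats browser/track/# lines and a bare first-line column header (e.g. `chr start end`) as header text rather than regions, and start/end parse failures return a descriptive error instead of bubbling up a bare parse error. A sketch from the Python side, assuming the RegionSet constructor accepts a local path as the Rust tests here do:

import tempfile

from gtars.models import RegionSet

# First line is a bare column header with no leading "#"; with this change
# it is collected into the header instead of causing a parse failure.
content = "chr\tstart\tend\nchr1\t10\t20\nchr1\t30\t34\n"

with tempfile.NamedTemporaryFile("w", suffix=".bed", delete=False) as f:
    f.write(content)
    path = f.name

rs = RegionSet(path)
print(rs.mean_region_width())  # (10 + 4) / 2 = 7.0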