Skip to content

Commit 16cb5bf

Browse files
authored
Merge pull request #1845 from anuprulez/dna_encoder
Adds one-hot and k-mer encoder for DNA sequences
2 parents b3855f9 + 7cb1dbb commit 16cb5bf

8 files changed

Lines changed: 466 additions & 96 deletions

tools/sklearn/main_macros.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
<macros>
2-
<token name="@VERSION@">1.0.11.1</token>
2+
<token name="@VERSION@">1.0.11.2</token>
33
<token name="@PROFILE@">24.2</token>
44

55
<xml name="python_requirements">
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
>1
2+
sequences
3+
>2
4+
ACGT
5+
>3
6+
TGCAA
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
>seq1
2+
ACGT
3+
>seq2
4+
TGCAA
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
>seq1
2+
ACGT
3+
>seq2
4+
TGCA

tools/sklearn/test-data/ohe_out_4.tabular

Lines changed: 0 additions & 8 deletions
This file was deleted.

tools/sklearn/test-data/ohe_out_5.tabular

Lines changed: 0 additions & 8 deletions
This file was deleted.

tools/sklearn/to_categorical.py

Lines changed: 150 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,176 @@
11
import argparse
22
import json
3+
import re
34
import warnings
45

6+
import h5py
57
import numpy as np
68
import pandas as pd
7-
from keras.utils import to_categorical
89

10+
warnings.simplefilter("ignore")
911

10-
def main(inputs, infile, outfile, num_classes=None):
11-
"""
12-
Parameter
13-
---------
14-
input : str
15-
File path to galaxy tool parameter
1612

17-
infile : str
18-
File paths of input vector
13+
def _get_longest_sequence_length(fasta_file):
14+
max_len = 0
15+
max_id = None
16+
for name in fasta_file.keys():
17+
seq_len = len(fasta_file[name])
18+
if seq_len > max_len:
19+
max_len = seq_len
20+
max_id = name
1921

20-
outfile : str
21-
File path to output matrix
22+
return max_len, max_id
2223

23-
num_classes : str
24-
Total number of classes. If None, this would be inferred as the (largest number in y) + 1
2524

26-
"""
27-
warnings.simplefilter("ignore")
25+
def encode_dna_sequences(fasta_path, padding, outfile, outfile_matrix):
26+
from galaxy_ml.preprocessors import GenomeOneHotEncoder
27+
import pyfaidx
2828

29-
with open(inputs, "r") as param_handler:
30-
params = json.load(param_handler)
29+
seq_length = None
30+
fasta_file = pyfaidx.Fasta(fasta_path)
31+
if padding:
32+
seq_length, max_id = _get_longest_sequence_length(fasta_file)
33+
print("Longest sequence is %s with length %d" % (max_id, seq_length))
34+
print("Padding: {}".format(padding))
35+
X = np.arange(len(fasta_file.keys())).reshape(-1, 1)
36+
genome_encoder = GenomeOneHotEncoder(
37+
fasta_path=fasta_path, seq_length=seq_length, padding=padding
38+
)
39+
genome_encoder.fit(X)
40+
encoded_dna_sequences = genome_encoder.transform(X)
41+
flatted_enc_seqs = encoded_dna_sequences.flatten().reshape(
42+
encoded_dna_sequences.shape[0], -1
43+
)
44+
np.savetxt(
45+
outfile, np.asarray(flatted_enc_seqs, dtype=int), fmt="%d", delimiter="\t"
46+
)
47+
with h5py.File(outfile_matrix, "w") as handle:
48+
handle.create_dataset("data", data=encoded_dna_sequences, compression="gzip")
3149

32-
input_header = params["header0"]
33-
header = "infer" if input_header else None
3450

35-
input_vector = pd.read_csv(infile, sep="\t", header=header)
51+
def seq_to_kmers(sequence, k=3):
    """Split a sequence into its overlapping k-mers.

    A window of width ``k`` is moved over ``sequence`` one position at a
    time, so a sequence of length ``n`` yields ``n - k + 1`` k-mers and a
    sequence shorter than ``k`` yields an empty list.

    Parameters
    ----------
    sequence : str
        Sequence to split.
    k : int, default=3
        Window width; must be at least 1.

    Returns
    -------
    list
        The k-mers, ordered by start position.

    Raises
    ------
    ValueError
        If ``k`` is smaller than 1.
    """
    if k < 1:
        # Without this guard k == 0 produces len(sequence) + 1 empty strings
        # and a negative k produces misleading partial slices.
        raise ValueError("k-mer size must be at least 1.")
    return [sequence[start:start + k] for start in range(len(sequence) - k + 1)]
53+
54+
55+
def normalize_dna_sequence(sequence):
    """Return ``sequence`` upper-cased with all whitespace removed.

    Parameters
    ----------
    sequence : str
        Raw sequence text, possibly containing line breaks or spaces.

    Returns
    -------
    str
        The upper-case sequence without any whitespace characters.
    """
    upper_cased = sequence.upper()
    whitespace_run = re.compile(r"\s+")
    return whitespace_run.sub("", upper_cased)
57+
58+
59+
def is_valid_dna_kmer(kmer):
    """Tell whether a k-mer consists only of accepted nucleotide letters.

    The accepted letters are the upper-case characters of
    ``ACGTRYSWKMBDHVN`` (the four bases plus ambiguity codes). The check is
    case-sensitive, and an empty k-mer is considered valid.

    Parameters
    ----------
    kmer : str
        The k-mer to check.

    Returns
    -------
    bool
        True when every character of ``kmer`` is an accepted letter.
    """
    accepted_letters = frozenset("ACGTRYSWKMBDHVN")
    return accepted_letters.issuperset(kmer)
62+
63+
64+
def build_kmer_vocabulary(sequences, k):
    """Assign an integer id to every distinct valid k-mer of the sequences.

    Ids 0 and 1 are reserved for the ``<PAD>`` and ``<UNK>`` tokens. Each
    k-mer that passes ``is_valid_dna_kmer`` receives the next free id the
    first time it is encountered, so ids follow first-occurrence order.

    Parameters
    ----------
    sequences : iterable of str
        Sequences to scan for k-mers.
    k : int
        k-mer size passed on to ``seq_to_kmers``.

    Returns
    -------
    dict
        Mapping from token (reserved token or k-mer) to integer id.

    Raises
    ------
    ValueError
        If no valid k-mer was found in any sequence.
    """
    kmer_to_id = {"<PAD>": 0, "<UNK>": 1}
    reserved_count = len(kmer_to_id)

    for sequence in sequences:
        valid_kmers = (
            kmer for kmer in seq_to_kmers(sequence, k) if is_valid_dna_kmer(kmer)
        )
        for kmer in valid_kmers:
            # setdefault leaves already-known k-mers untouched; a new k-mer
            # gets the current vocabulary size as its id.
            kmer_to_id.setdefault(kmer, len(kmer_to_id))

    # Only the reserved tokens are present: nothing usable was found.
    if len(kmer_to_id) == reserved_count:
        raise ValueError(
            "No DNA k-mers were generated. Check that k is not longer than all sequences."
        )

    return kmer_to_id
77+
78+
79+
def encode_sequence_kmers(sequence, vocabulary, k):
    """Translate a sequence into the list of ids of its valid k-mers.

    Parameters
    ----------
    sequence : str
        Sequence to encode.
    vocabulary : dict
        Mapping from k-mer to integer id; it must contain ``"<UNK>"`` whenever
        the sequence has at least one valid k-mer.
    k : int
        k-mer size passed on to ``seq_to_kmers``.

    Returns
    -------
    list of int
        One id per valid k-mer, in sequence order. A valid k-mer that is
        missing from ``vocabulary`` is encoded with the ``<UNK>`` id.
    """
    token_ids = []
    for kmer in seq_to_kmers(sequence, k):
        # NOTE(review): k-mers that fail validation are omitted from the
        # output entirely; they are not mapped to <UNK>.
        if not is_valid_dna_kmer(kmer):
            continue
        token_ids.append(vocabulary.get(kmer, vocabulary["<UNK>"]))
    return token_ids
85+
3686

87+
def pad_encoded_sequences(encoded_sequences, pad_value=0):
    """Right-pad every encoded sequence to the length of the longest one.

    Parameters
    ----------
    encoded_sequences : iterable of sequence
        Encoded sequences (for example lists of integer ids). Any iterable
        is accepted, including a one-shot iterator; an empty input is allowed.
    pad_value : object, default=0
        Value appended to the shorter sequences.

    Returns
    -------
    list of list
        New lists of equal length; the inputs are not modified. An empty
        input gives an empty list.
    """
    # Materialise once so the rows can be measured and then padded, even
    # when the caller passes an iterator.
    rows = [list(sequence) for sequence in encoded_sequences]
    # default=0 keeps max() from raising when there are no rows at all.
    longest = max((len(row) for row in rows), default=0)
    return [row + [pad_value] * (longest - len(row)) for row in rows]
93+
94+
95+
def encode_dna_kmers(fasta_path, k, outfile, outfile_vocab):
    """Encode the sequences of a FASTA file as padded k-mer id rows.

    Every record is normalised, split into k-mers and translated into
    integer ids using a vocabulary built from the same file. The rows are
    padded with the ``<PAD>`` id to a common length.

    Parameters
    ----------
    fasta_path : str
        Path of the FASTA file, opened with ``pyfaidx.Fasta``.
    k : int
        k-mer size; must be at least 1.
    outfile : str
        Path of the tab-delimited integer matrix to write (one row per
        sequence).
    outfile_vocab : str
        Path of the JSON file receiving the k-mer vocabulary.

    Raises
    ------
    ValueError
        If ``k`` is smaller than 1, or if no valid k-mer could be generated.
    """
    import pyfaidx

    if k < 1:
        raise ValueError("k-mer size must be at least 1.")

    records = pyfaidx.Fasta(fasta_path)
    sequences = []
    for name in records.keys():
        sequences.append(normalize_dna_sequence(str(records[name])))

    vocabulary = build_kmer_vocabulary(sequences, k)
    token_rows = [
        encode_sequence_kmers(sequence, vocabulary, k) for sequence in sequences
    ]
    pad_id = vocabulary["<PAD>"]
    padded_rows = pad_encoded_sequences(token_rows, pad_value=pad_id)
    id_matrix = np.asarray(padded_rows, dtype=int)
    np.savetxt(outfile, id_matrix, fmt="%d", delimiter="\t")

    # Keep insertion order in the JSON so ids appear in ascending order.
    with open(outfile_vocab, "w") as handle:
        json.dump(vocabulary, handle, indent=4, sort_keys=False)
        handle.write("\n")
117+
118+
119+
def encode_labels(infile, input_header, outfile, num_classes=None):
    """Convert a tab-separated label file into a categorical matrix file.

    Parameters
    ----------
    infile : str
        Path of the tab-separated input file holding the labels.
    input_header : object
        Truthy when the first row of ``infile`` is a header; pandas then
        infers the header, otherwise the file is read without one.
    outfile : str
        Path of the tab-delimited integer matrix to write.
    num_classes : int, optional
        Number of classes handed to ``keras.utils.to_categorical``.
    """
    from keras.utils import to_categorical

    if input_header:
        header = "infer"
    else:
        header = None

    labels = pd.read_csv(infile, sep="\t", header=header)
    categorical = to_categorical(labels, num_classes=num_classes)
    # Cast to int so the matrix is written as whole numbers.
    as_integers = np.asarray(categorical, dtype=int)
    np.savetxt(outfile, as_integers, fmt="%d", delimiter="\t")
126+
127+
128+
def main(args):
129+
task_type = args.encoder_task_type
130+
num_classes = args.num_classes
131+
header = "infer" if args.labels_header == "booltrue" else None
132+
padding = True if args.padding == "booltrue" else False
133+
sequence_encoding = args.sequence_encoding
134+
kmer_size = args.kmer_size
38135

39-
np.savetxt(outfile, output_matrix, fmt="%d", delimiter="\t")
136+
if task_type == "label_encoder":
137+
encode_labels(args.labels_path, header, args.outfile, num_classes=num_classes)
138+
elif task_type == "dna_encoder":
139+
if sequence_encoding == "one_hot":
140+
encode_dna_sequences(
141+
args.fasta_path, padding, args.outfile, args.outfile_matrix
142+
)
143+
elif sequence_encoding == "kmer":
144+
encode_dna_kmers(
145+
args.fasta_path, kmer_size, args.outfile, args.outfile_vocab
146+
)
147+
else:
148+
raise ValueError(
149+
"Unsupported DNA sequence encoding: %s" % sequence_encoding
150+
)
151+
else:
152+
raise ValueError("Unsupported encoder type: %s" % task_type)
40153

41154

42155
if __name__ == "__main__":
43156
aparser = argparse.ArgumentParser()
44-
aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
45-
aparser.add_argument("-y", "--infile", dest="infile")
157+
aparser.add_argument("-l", "--labels_path", dest="labels_path")
158+
aparser.add_argument("-d", "--labels_header", dest="labels_header", default=False)
159+
aparser.add_argument(
160+
"-t", "--encoder_task_type", dest="encoder_task_type", required=True
161+
)
162+
aparser.add_argument(
163+
"-y", "--num_classes", dest="num_classes", type=int, default=None
164+
)
165+
aparser.add_argument("-p", "--padding", dest="padding", default="boolfalse")
46166
aparser.add_argument(
47-
"-n", "--num_classes", dest="num_classes", type=int, default=None
167+
"-s", "--sequence_encoding", dest="sequence_encoding", default="one_hot"
48168
)
49-
aparser.add_argument("-o", "--outfile", dest="outfile")
169+
aparser.add_argument("-k", "--kmer_size", dest="kmer_size", type=int, default=3)
170+
aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
171+
aparser.add_argument("-o", "--outfile", dest="outfile", required=True)
172+
aparser.add_argument("-m", "--outfile_matrix", dest="outfile_matrix")
173+
aparser.add_argument("-v", "--outfile_vocab", dest="outfile_vocab")
50174
args = aparser.parse_args()
51175

52-
main(args.inputs, args.infile, args.outfile, args.num_classes)
176+
main(args)

0 commit comments

Comments
 (0)