Skip to content

Commit bf93c91

Browse files
Merge pull request #104 from ArnovanHilten/epistasis-interpretation
Interpretation update - Adds the interpretation module with six interpretation methods - get_weight_scores: uses the weights to calculate the importance of each feature and node - DeepExplain: uses the gradient (see DeepExplain) to calculate the importance - RLIPP: uses logistic regression with signals to and from the node to calculate a measure of non-linearity for all nodes - NID: Finds interacting features based on the features with the strongest weights - DFIM: perturbs each input (or N inputs in the order of importance), and tracks which other features change importance to find interacting features - PathExplain: Uses the Expected Hessian to find interacting features - bugfixes for converting plink2 files - module for converting topology to npz matrices - add one-hot encoding support
2 parents a911822 + 06f0122 commit bf93c91

35 files changed

+2529
-490
lines changed

.github/workflows/tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jobs:
88
runs-on: ubuntu-latest
99
strategy:
1010
matrix:
11-
python-version: ["3.7", "3.8"]
11+
python-version: ["3.7", "3.8", "3.9"]
1212

1313
steps:
1414
- uses: actions/checkout@v3

GenNet.py

+63-7
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,33 @@
66
import argparse
77

88
sys.path.insert(1, os.path.dirname(os.getcwd()) + "/GenNet_utils/")
9-
from GenNet_utils.Create_plots import plot
10-
from GenNet_utils.Train_network import train_classification, train_regression
11-
from GenNet_utils.Convert import convert
12-
from GenNet_utils.Topology import topology
139

1410

1511
def main():
    """Parse the command line and dispatch to the handler of the selected mode.

    Each handler is imported inside its own branch, so only the module that
    belongs to the requested mode is loaded.

    For ``train`` the problem type is translated into the boolean
    ``args.regression`` before ``train_model`` is called. An unknown problem
    type prints the error message and terminates with exit status 1 instead
    of starting a training run with ``args.regression`` unset.
    """
    args = ArgumentParser().parse_cmd_args()

    if args.mode == 'train':
        if args.problem_type == "classification":
            args.regression = False
        elif args.problem_type == "regression":
            args.regression = True
        else:
            print('something went wrong invalid problem type', args.problem_type)
            # Without a valid problem type args.regression is never set;
            # stop here rather than fail later inside the training code.
            raise SystemExit(1)
        from GenNet_utils.Train_network import train_model
        train_model(args)
    elif args.mode == "plot":
        from GenNet_utils.Create_plots import plot
        plot(args)
    elif args.mode == 'convert':
        from GenNet_utils.Convert import convert
        convert(args)
    elif args.mode == "topology":
        from GenNet_utils.Topology import topology
        topology(args)
    elif args.mode == "interpret":
        from GenNet_utils.Interpret import interpret
        interpret(args)
3136

3237

3338
class ArgumentParser():
@@ -51,6 +56,9 @@ def __init__(self):
5156
parser_topology = subparsers.add_parser("topology", help="Create standard topology files")
5257
self.make_parser_topology(parser_topology)
5358

59+
parser_interpret = subparsers.add_parser("interpret", help="Post-hoc interpretation analysis on the network")
60+
self.make_parser_interpret(parser_interpret)
61+
5462
self.parser = parser
5563

5664
def parse_cmd_args(self):
@@ -239,7 +247,11 @@ def make_parser_train(self, parser_train):
239247
action='store_true',
240248
default=False,
241249
help='Flag for one hot encoding as a first layer in the network')
242-
250+
parser_train.add_argument(
251+
"-init_linear",
252+
action='store_true',
253+
default=False,
254+
help='initialize the one-hot encoding for the neural network with a linear assumption')
243255
return parser_train
244256

245257
def make_parser_plot(self, parser_plot):
@@ -298,5 +310,49 @@ def make_parser_topology(self, parser_topology):
298310
return parser_topology
299311

300312

313+
314+
def make_parser_interpret(self, parser_topology):
    """Register the command-line options of the ``interpret`` sub-command.

    Args:
        parser_topology: the argparse (sub)parser that receives the options.
            The name is historical; it is the parser of the interpret mode.

    Returns:
        The same parser object, with the options added.
    """
    # (flag, keyword arguments) in the order they appear in the help output.
    option_specs = [
        ("-type", dict(
            default='get_weight_scores', type=str,
            choices=['get_weight_scores', 'NID', 'RLIPP', 'DFIM', "PathExplain", "DeepExplain"],
            help="choose interpretation method, choice")),
        ("-resultpath", dict(
            type=str,
            required=True,
            help="Path to the folder with the trained network (resultfolder) ")),
        ('-layer', dict(
            type=int,
            required=False,
            help='Select a layer for interpretation only necessary for NID')),
        ('-num_eval', dict(
            type=int,
            required=False,
            default=100,
            help='Select the number of SNPs to eval in DFIM')),
        ('-start_rank', dict(
            type=int,
            required=False,
            default=0,
            help='Multiprocessing, start from Nth ranked important variant')),
        ('-end_rank', dict(
            type=int,
            required=False,
            default=0,
            help='Multiprocessing, stop at Nth ranked important SNP')),
        ('-num_sample_pat', dict(
            type=int,
            required=False,
            default=1000,
            help='Select a number of patients to sample for DFIM')),
    ]

    register = parser_topology.add_argument
    for flag, options in option_specs:
        register(flag, **options)

    return parser_topology
355+
356+
301357
if __name__ == '__main__':
302358
main()

GenNet_utils/Convert.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def merge_hdf5_hase(args):
7575
f = tables.open_file(args.outfolder + args.study_name + '_step2_merged_genotype.h5', mode='a')
7676
for i in tqdm.tqdm(range(number_of_files)):
7777
gen_tmp = h5py.File(filepath_hase.format(i), 'r')['genotype']
78-
f.root.data.append(np.array(np.round(gen_tmp[:, :]), dtype=np.int))
78+
f.root.data.append(np.array(np.round(gen_tmp[:, :]), dtype=int))
7979
f.close()
8080

8181
args.outfolder = args.genotype
@@ -365,7 +365,7 @@ def merge_transpose(args):
365365
print("chunking is not necessary")
366366
for job_n in tqdm.tqdm(range(args.n_jobs)):
367367
gen_tmp = tables.open_file(args.genotype + args.study_name + '_step5_genotype_transposed_' + str(job_n) + '.h5', mode='r')
368-
f.root.data.append(np.array(np.round(gen_tmp.root.data[:, :]), dtype=np.int))
368+
f.root.data.append(np.array(np.round(gen_tmp.root.data[:, :]), dtype=int))
369369
gen_tmp.close()
370370
f.close()
371371
else:
@@ -375,7 +375,7 @@ def merge_transpose(args):
375375
for chunckblock in range(int(np.ceil(gen_tmp.root.data.shape[0] / chunk))):
376376
begins = chunckblock * chunk
377377
tills = min(((chunckblock + 1) * chunk), gen_tmp.root.data.shape[0])
378-
f.root.data.append(np.array(np.round(gen_tmp.root.data[begins:tills, :]), dtype=np.int))
378+
f.root.data.append(np.array(np.round(gen_tmp.root.data[begins:tills, :]), dtype=int))
379379
gen_tmp.close()
380380
f.close()
381381
print("completed")

GenNet_utils/Convert_topology_npz.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import argparse
2+
import numpy as np
3+
import pandas as pd
4+
import os
5+
import argparse
6+
import numpy as np
7+
import pandas as pd
8+
import os
9+
from scipy import sparse
10+
11+
12+
def main():
    """
    Convert a topology.csv file into a sparse boolean SNP-to-gene mask (.npz).

    Command-line args:
        snp: the name of the column in the topology.csv dataset with the ID for the SNP column
        gene: the name of the column in the topology.csv dataset with the ID for the gene column
        direc: (Optional) the directory where the topology.csv file is located, if omitted it takes the current directory
        file_name: (Optional) the name of the file to save as, defaults to "SNP_gene_mask"

    Output: <file_name>.npz, a scipy sparse COO matrix with one True entry per
    row of topology.csv at position (snp id, gene id). The shape is
    (max snp id + 1, max gene id + 1). The file is written relative to the
    working directory, i.e. inside ``direc`` when that was given and exists.
    """
    parser = argparse.ArgumentParser(description="A simple script with command-line arguments")
    parser.add_argument("--snp", help="Your snp", required=True)
    parser.add_argument("--gene", help="Your gene", required=True)
    parser.add_argument("--direc", help="Your Directory", required=False)
    parser.add_argument("--file_name", help="Your file name", default="SNP_gene_mask", required=False)
    args = parser.parse_args()

    if args.direc:
        try:
            os.chdir(args.direc)
            print(f"Navigated to directory: {os.getcwd()}")
        except FileNotFoundError:
            # Best effort: report and fall through to the current directory.
            print(f"Directory '{args.direc}' not found.")

    snp_level = args.snp
    gene_level = args.gene
    topology = pd.read_csv("topology.csv")

    # Builtin bool instead of np.bool: the alias was removed in NumPy 1.24
    # and raises AttributeError there; the resulting dtype is identical.
    data = np.ones(len(topology), bool)
    coord = (topology[snp_level].values, topology[gene_level].values)
    mask_shape = (topology[snp_level].max() + 1, topology[gene_level].max() + 1)
    SNP_gene_matrix = sparse.coo_matrix((data, coord), shape=mask_shape)

    file_name = args.file_name
    sparse.save_npz(file_name, SNP_gene_matrix)
44+
45+
if __name__ == "__main__":
46+
main()
47+
48+

GenNet_utils/Create_network.py

+73-5
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def one_hot_input(input_layer):
9090
def add_covariates(model, input_cov, num_covariates, regression, negative_values_ytrain, mean_ytrain, l1_value, L1_act):
9191
if num_covariates > 0:
9292
model = activation_layer(model, regression, negative_values_ytrain)
93-
model = K.layers.concatenate([model, input_cov], axis=1)
94-
model = K.layers.BatchNormalization(center=False, scale=False)(model)
93+
model = K.layers.concatenate([model, input_cov], axis=1, name="concatenate_cov")
94+
model = K.layers.BatchNormalization(center=False, scale=False, name="batchnorm_cov")(model)
9595
model = K.layers.Dense(units=1, name="output_layer_cov",
9696
kernel_regularizer=tf.keras.regularizers.l1(l=l1_value),
9797
activity_regularizer=K.regularizers.l1(L1_act),
@@ -232,7 +232,7 @@ def create_network_from_csv(datapath,
232232
model = K.layers.Reshape(input_shape=(inputsize,), target_shape=(inputsize, 1))(input_layer)
233233

234234
for i in range(len(columns) - 1):
235-
matrix_ones = np.ones(len(network_csv[[columns[i], columns[i + 1]]]), np.bool)
235+
matrix_ones = np.ones(len(network_csv[[columns[i], columns[i + 1]]]), bool)
236236
matrix_coord = (network_csv[columns[i]].values, network_csv[columns[i + 1]].values)
237237
if i == 0:
238238
matrixshape = (inputsize, network_csv[columns[i + 1]].max() + 1)
@@ -331,7 +331,7 @@ def gene_network_multiple_filters(datapath,
331331
mean_ytrain = 0
332332
negative_values_ytrain = False
333333

334-
print("height_multiple_filters with", filters, "filters")
334+
print("gene_network_multiple_filters with", filters, "filters")
335335

336336
masks = []
337337
for npz_path in glob.glob(datapath + '/*.npz'):
@@ -482,4 +482,72 @@ def regression_height(inputsize, num_covariates=2, l1_value=0.001):
482482
print(model.summary())
483483

484484
return model, masks
485-
485+
486+
487+
488+
def remove_batchnorm_model(model, masks, keep_cov = False):
    """Rebuild ``model`` as a sequential chain without BatchNormalization layers.

    The layers of ``model`` (all but the first) are re-created in order and
    chained on a fresh input:
      * BatchNormalization layers are dropped,
      * LocallyDirected1D layers are re-created from their ``filters`` and
        ``name`` with the next entry of ``masks``,
      * layers with "_cov" in their name are dropped unless ``keep_cov``,
      * every other layer is re-created from its config.
    The weights of the kept layers are then copied from ``model``.

    Args:
        model: trained Keras model; ``model.input_shape[0]`` is used as the
            shape of the new input, so a multi-input model is assumed.
        masks: connectivity masks, one per LocallyDirected1D layer, in the
            order those layers occur in ``model.layers``.
        keep_cov: when True, layers with "_cov" in their name are rebuilt too.
            NOTE(review): layers are chained one after another, so a
            multi-input covariate branch is presumably not reproducible this
            way - confirm before relying on keep_cov=True.

    Returns:
        The new tf.keras.Model (its summary is printed as a side effect).
    """
    inputs = tf.keras.Input(shape=model.input_shape[0][1:])
    x = inputs

    mask_num = 0
    for layer in model.layers[1:]:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            # Skip BatchNormalization layers
            continue
        if isinstance(layer, LocallyDirected1D):
            # LocallyDirected1D needs its mask passed explicitly
            config = layer.get_config()
            x = LocallyDirected1D(filters=config['filters'],
                                  mask=masks[mask_num],
                                  name=config['name'])(x)
            mask_num += 1
        elif "_cov" in layer.name and not keep_cov:
            continue
        else:
            # Add other layers as they are
            x = layer.__class__.from_config(layer.get_config())(x)

    # Create the new model
    new_model = tf.keras.Model(inputs=inputs, outputs=x)

    # Copy weights by layer name rather than by position: the rebuilt layers
    # keep their original names, while positions shift as soon as a skipped
    # layer is not at the very end of model.layers.
    for new_layer in new_model.layers:
        if isinstance(new_layer, tf.keras.layers.InputLayer):
            continue  # fresh input layer: different name, no weights
        new_layer.set_weights(model.get_layer(new_layer.name).get_weights())

    print(new_model.summary())

    return new_model
522+
523+
524+
def remove_cov(model, masks):
    """Rebuild ``model`` as a sequential chain without the covariate layers.

    The layers of ``model`` (all but the first) are re-created in order and
    chained on a fresh input:
      * LocallyDirected1D layers are re-created from their ``filters`` and
        ``name`` with the next entry of ``masks``,
      * layers with "_cov" in their name are dropped,
      * every other layer (BatchNormalization included) is re-created from
        its config.
    The weights of the kept layers are then copied from ``model``.

    Args:
        model: trained Keras model; ``model.input_shape[0]`` is used as the
            shape of the new input, so a multi-input model is assumed.
        masks: connectivity masks, one per LocallyDirected1D layer, in the
            order those layers occur in ``model.layers``.

    Returns:
        The new tf.keras.Model (its summary is printed as a side effect).
    """
    inputs = tf.keras.Input(shape=model.input_shape[0][1:])
    x = inputs

    mask_num = 0
    for layer in model.layers[1:]:
        if isinstance(layer, LocallyDirected1D):
            # LocallyDirected1D needs its mask passed explicitly
            config = layer.get_config()
            x = LocallyDirected1D(filters=config['filters'],
                                  mask=masks[mask_num],
                                  name=config['name'])(x)
            mask_num += 1
        elif "_cov" in layer.name:
            # Drop everything that belongs to the covariate branch
            continue
        else:
            # Add other layers as they are
            x = layer.__class__.from_config(layer.get_config())(x)

    # Create the new model
    new_model = tf.keras.Model(inputs=inputs, outputs=x)

    # Copy weights by layer name rather than by position: the rebuilt layers
    # keep their original names, while positions shift as soon as a skipped
    # "_cov" layer is not at the very end of model.layers.
    for new_layer in new_model.layers:
        if isinstance(new_layer, tf.keras.layers.InputLayer):
            continue  # fresh input layer: different name, no weights
        new_layer.set_weights(model.get_layer(new_layer.name).get_weights())

    print(new_model.summary())

    return new_model

GenNet_utils/Dataloader.py

+22-19
Original file line numberDiff line numberDiff line change
@@ -92,19 +92,6 @@ def get_labels(datapath, set_number):
9292
return ybatch
9393

9494

95-
def get_data(datapath, genotype_path, set_number):
96-
print("depreciated")
97-
groundtruth = pd.read_csv(datapath + "/subjects.csv")
98-
h5file = tables.open_file(genotype_path + "genotype.h5", "r")
99-
groundtruth = groundtruth[groundtruth["set"] == set_number]
100-
xbatchid = np.array(groundtruth["genotype_row"].values, dtype=np.int64)
101-
xbatch = h5file.root.data[xbatchid, :]
102-
ybatch = np.reshape(np.array(groundtruth["labels"].values), (-1, 1))
103-
h5file.close()
104-
return xbatch, ybatch
105-
106-
107-
10895

10996

11097
class TrainDataGenerator(K.utils.Sequence):
@@ -160,7 +147,7 @@ def single_genotype_matrix(self, idx):
160147
ybatch = self.training_subjects["labels"].iloc[batchindexes]
161148
xcov = self.training_subjects.filter(like="cov_").iloc[batchindexes]
162149
xcov = xcov.values
163-
xbatchid = np.array(self.training_subjects["genotype_row"].iloc[batchindexes], dtype=np.int64)
150+
xbatchid = np.array(self.training_subjects["genotype_row"].iloc[batchindexes], dtype=int)
164151
xbatch = genotype_hdf.root.data[xbatchid, :]
165152
xbatch = self.if_one_hot(xbatch)
166153
ybatch = np.reshape(np.array(ybatch), (-1, 1))
@@ -181,7 +168,7 @@ def multi_genotype_matrix(self, idx):
181168
for i in subjects_current_batch["chunk_id"].unique():
182169
genotype_hdf = tables.open_file(self.genotype_path + "/" + str(i) + self.h5filenames + ".h5", "r")
183170
subjects_current_chunk = subjects_current_batch[subjects_current_batch["chunk_id"] == i]
184-
xbatchid = np.array(subjects_current_chunk["genotype_row"].values, dtype=np.int64)
171+
xbatchid = np.array(subjects_current_chunk["genotype_row"].values, dtype=int)
185172
if len(xbatchid) > 1:
186173
pass
187174
else:
@@ -252,21 +239,20 @@ def if_one_hot(self, xbatch):
252239
else:
253240
print("unexpected shape!")
254241
return xbatch
255-
242+
256243
def single_genotype_matrix(self, idx):
257244
genotype_hdf = tables.open_file(self.genotype_path + "/genotype.h5", "r")
258245
ybatch = self.eval_subjects["labels"].iloc[idx * self.batch_size:((idx + 1) * self.batch_size)]
259246
xcov = self.eval_subjects.filter(like="cov_").iloc[idx * self.batch_size:((idx + 1) * self.batch_size)]
260247
xcov = xcov.values
261248
xbatchid = np.array(self.eval_subjects["genotype_row"].iloc[idx * self.batch_size:((idx + 1) * self.batch_size)],
262-
dtype=np.int64)
249+
dtype=int)
263250
xbatch = genotype_hdf.root.data[xbatchid, :]
264251
xbatch = self.if_one_hot(xbatch)
265252
ybatch = np.reshape(np.array(ybatch), (-1, 1))
266253
genotype_hdf.close()
267254
return [xbatch, xcov], ybatch
268255

269-
270256
def multi_genotype_matrix(self, idx):
271257
subjects_current_batch = self.eval_subjects.iloc[idx * self.batch_size:((idx + 1) * self.batch_size)]
272258
subjects_current_batch["batch_index"] = np.arange(subjects_current_batch.shape[0])
@@ -276,7 +262,7 @@ def multi_genotype_matrix(self, idx):
276262
for i in subjects_current_batch["chunk_id"].unique():
277263
genotype_hdf = tables.open_file(self.genotype_path + "/" + str(i) + self.h5filenames + ".h5", "r")
278264
subjects_current_chunk = subjects_current_batch[subjects_current_batch["chunk_id"] == i]
279-
xbatchid = np.array(subjects_current_chunk["genotype_row"].values, dtype=np.int64)
265+
xbatchid = np.array(subjects_current_chunk["genotype_row"].values, dtype=int)
280266
xbatch[subjects_current_chunk["batch_index"].values, :] = genotype_hdf.root.data[xbatchid, :]
281267
genotype_hdf.close()
282268

@@ -286,5 +272,22 @@ def multi_genotype_matrix(self, idx):
286272
return [xbatch, xcov], ybatch
287273

288274

275+
def get_data(self, sample_pat=0):
    """Load the evaluation subjects in one go instead of batch by batch.

    Args:
        sample_pat: when > 0, ``self.eval_subjects`` is replaced by a random
            sample of that many subjects (fixed ``random_state=1``); the
            replacement persists on the generator. Requests larger than the
            number of available subjects are clamped to that number.

    Returns:
        ``([xbatch, xcov], ybatch)`` where ``xbatch`` holds the genotype rows
        read from ``<genotype_path>/genotype.h5`` (passed through
        ``if_one_hot``), ``xcov`` the values of the ``cov_*`` columns and
        ``ybatch`` the labels reshaped to a column. All three describe the
        same subjects in the same order.
    """
    if sample_pat > 0:
        n_draw = min(sample_pat, len(self.eval_subjects))
        self.eval_subjects = self.eval_subjects.sample(n=n_draw, random_state=1)

    # Read labels, covariates and row ids from the same (possibly sampled)
    # frame; taking the labels before sampling would return one label per
    # original subject next to genotypes of the sampled subjects only.
    subjects = self.eval_subjects
    ybatch = np.reshape(np.array(subjects["labels"]), (-1, 1))
    xcov = subjects.filter(like="cov_").values
    xbatchid = np.array(subjects["genotype_row"].values, dtype=int)

    genotype_hdf = tables.open_file(self.genotype_path + "/genotype.h5", "r")
    try:
        xbatch = genotype_hdf.root.data[xbatchid,...]
    finally:
        # Release the file handle even when the read fails.
        genotype_hdf.close()

    xbatch = self.if_one_hot(xbatch)
    return [xbatch, xcov], ybatch
289292

290293

0 commit comments

Comments
 (0)