Skip to content

Commit 380f19a

Browse files
authored
add JSON export to CLI (#23)
* add export of ORA and file overwrite prompt * add output to each method
1 parent eee42e4 commit 380f19a

File tree

4 files changed

+95
-121
lines changed

4 files changed

+95
-121
lines changed

Cargo.lock

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ repository = "https://github.com/bzhanglab/webgestalt_rust"
1616
bincode = "1.3.3"
1717
clap = { version = "4.4.15", features = ["derive"] }
1818
owo-colors = { version = "4.0.0", features = ["supports-colors"] }
19+
serde_json = "1.0.116"
1920
webgestalt_lib = { version = "0.3.0", path = "webgestalt_lib" }
2021

2122
[profile.release]

src/main.rs

Lines changed: 89 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
1-
use bincode::deserialize_from;
21
use clap::{Args, Parser};
32
use clap::{Subcommand, ValueEnum};
43
use owo_colors::{OwoColorize, Stream::Stdout, Style};
5-
use std::io::{BufReader, Write};
4+
use std::io::Write;
65
use std::{fs::File, time::Instant};
76
use webgestalt_lib::methods::gsea::GSEAConfig;
87
use webgestalt_lib::methods::multilist::{combine_gmts, MultiListMethod, NormalizationMethod};
98
use webgestalt_lib::methods::nta::NTAConfig;
109
use webgestalt_lib::methods::ora::ORAConfig;
1110
use webgestalt_lib::readers::utils::Item;
1211
use webgestalt_lib::readers::{read_gmt_file, read_rank_file};
13-
use webgestalt_lib::{MalformedError, WebGestaltError};
1412

1513
/// WebGestalt CLI.
1614
/// ORA and GSEA enrichment tool.
@@ -24,8 +22,6 @@ struct CliArgs {
2422

2523
#[derive(Subcommand)]
2624
enum Commands {
27-
/// Benchmark different file formats for gmt. TODO: Remove later
28-
Benchmark,
2925
/// Run provided examples for various types of analyses
3026
Example(ExampleArgs),
3127
/// Run GSEA on the provided files
@@ -34,8 +30,6 @@ enum Commands {
3430
Ora(ORAArgs),
3531
/// Run NTA on the provided files
3632
Nta(NtaArgs),
37-
/// Run a test
38-
Test,
3933
/// Combine multiple files into a single file
4034
Combine(CombineArgs),
4135
}
@@ -65,7 +59,7 @@ struct NtaArgs {
6559
seeds: String,
6660
/// Output path for the results
6761
#[arg(short, long)]
68-
out: String,
62+
output: String,
6963
/// Probability of random walk resetting
7064
#[arg(short, long, default_value = "0.5")]
7165
reset_probability: f64,
@@ -77,8 +71,8 @@ struct NtaArgs {
7771
neighborhood_size: usize,
7872
/// Method to use for NTA
7973
/// Options: prioritize, expand
80-
#[arg(short, long)]
81-
method: Option<NTAMethodClap>,
74+
#[arg(short, long, default_value = "prioritize")]
75+
method: NTAMethodClap,
8276
}
8377

8478
#[derive(ValueEnum, Clone)]
@@ -90,19 +84,29 @@ enum NTAMethodClap {
9084
#[derive(Args)]
9185
struct GseaArgs {
9286
/// Path to the GMT file of interest
93-
gmt: Option<String>,
87+
#[arg(short, long)]
88+
gmt: String,
9489
/// Path to the rank file of interest
95-
rnk: Option<String>,
90+
#[arg(short, long)]
91+
rnk: String,
92+
/// Output path for the results
93+
#[arg(short, long, default_value = "out.json")]
94+
output: String,
9695
}
97-
98-
#[derive(Args)]
96+
#[derive(Parser)]
9997
struct ORAArgs {
10098
/// Path to the GMT file of interest
101-
gmt: Option<String>,
99+
#[arg(short, long)]
100+
gmt: String,
102101
/// Path to the file containing the interesting analytes
103-
interest: Option<String>,
102+
#[arg(short, long)]
103+
interest: String,
104+
/// Output path for the results
105+
#[arg(short, long, default_value = "out.json")]
106+
output: String,
104107
/// Path to the file containing the reference list
105-
reference: Option<String>,
108+
#[arg(short, long)]
109+
reference: String,
106110
}
107111

108112
#[derive(Args)]
@@ -146,13 +150,43 @@ struct CombineListArgs {
146150
files: Vec<String>,
147151
}
148152

153+
/// Asks `question` on stdout and waits for a yes/no answer on stdin.
///
/// The prompt is repeated until the user enters `y` or `n` (case-insensitive,
/// surrounding whitespace ignored). After each line that is read, the terminal
/// is cleared with ANSI escape codes before the answer is evaluated.
///
/// # Parameters
/// question - The text shown to the user; ` (y/n): ` is appended to it.
///
/// # Returns
/// `true` when the user answers `y`, `false` when the user answers `n`.
/// Also returns `false` when stdin reaches end-of-file (closed or redirected
/// input), so a non-interactive run declines instead of prompting forever.
///
/// # Panics
/// Panics if stdout cannot be flushed or if reading from stdin fails.
fn prompt_yes_no(question: &str) -> bool {
    loop {
        print!("{} (y/n): ", question);
        // `print!` does not flush, so push the prompt out before blocking on stdin.
        std::io::stdout().flush().expect("Could not flush stdout!");

        let mut input = String::new();
        let bytes_read = std::io::stdin()
            .read_line(&mut input)
            .expect("Could not read line");
        if bytes_read == 0 {
            // Zero bytes means end-of-file: no answer can ever arrive, and looping
            // would print the prompt endlessly. Decline, which is the safe choice.
            println!();
            return false;
        }

        // ESC[2J clears the screen; ESC[1;1H moves the cursor to the top-left corner.
        print!("\x1B[2J\x1B[1;1H");
        std::io::stdout().flush().expect("Could not flush stdout!");

        match input.trim().to_lowercase().as_str() {
            "y" => return true,
            "n" => return false,
            _ => println!("Invalid input. Please enter 'y' or 'n'."),
        }
    }
}
171+
172+
fn check_and_overwrite(file_path: &str) {
173+
// Check if the file exists
174+
if std::path::Path::new(file_path).exists() {
175+
// Check if the user wants to overwrite the file
176+
if !prompt_yes_no(&format!(
177+
"File at {} already exists. Do you want to overwrite it?",
178+
file_path
179+
)) {
180+
println!("Stopping analysis.");
181+
std::process::exit(1);
182+
};
183+
}
184+
}
185+
149186
fn main() {
150187
println!("WebGestalt CLI v{}", env!("CARGO_PKG_VERSION"));
151188
let args = CliArgs::parse();
152189
match &args.command {
153-
Some(Commands::Benchmark) => {
154-
benchmark();
155-
}
156190
Some(Commands::Example(ex)) => match &ex.commands {
157191
Some(ExampleOptions::Gsea) => {
158192
let gene_list = webgestalt_lib::readers::read_rank_file(
@@ -177,7 +211,7 @@ fn main() {
177211
"webgestalt_lib/data/genelist.txt".to_owned(),
178212
"webgestalt_lib/data/reference.txt".to_owned(),
179213
);
180-
let gmtcount = gmt.len();
214+
let gmt_count = gmt.len();
181215
let start = Instant::now();
182216
let x: Vec<webgestalt_lib::methods::ora::ORAResult> =
183217
webgestalt_lib::methods::ora::get_ora(
@@ -187,6 +221,8 @@ fn main() {
187221
ORAConfig::default(),
188222
);
189223
let mut count = 0;
224+
let output_file = File::create("test.json").expect("Could not create output file!");
225+
serde_json::to_writer(output_file, &x).expect("Could not create JSON file!");
190226
for i in x {
191227
if i.p < 0.05 && i.fdr < 0.05 {
192228
println!("{}: {}, {}, {}", i.set, i.p, i.fdr, i.overlap);
@@ -196,56 +232,48 @@ fn main() {
196232
let duration = start.elapsed();
197233
println!(
198234
"ORA\nTime took: {:?}\nFound {} significant pathways out of {} pathways",
199-
duration, count, gmtcount
235+
duration, count, gmt_count
200236
);
201237
}
202238
_ => {
203239
println!("Please select a valid example: ora or gsea.");
204240
}
205241
},
206242
Some(Commands::Gsea(gsea_args)) => {
207-
let style = Style::new().red().bold();
208-
if gsea_args.gmt.is_none() || gsea_args.rnk.is_none() {
209-
println!(
210-
"{}: DID NOT PROVIDE PATHS FOR GMT AND RANK FILE.",
211-
"ERROR".if_supports_color(Stdout, |text| text.style(style))
212-
);
213-
return;
214-
}
215-
let gene_list = webgestalt_lib::readers::read_rank_file(gsea_args.rnk.clone().unwrap())
243+
check_and_overwrite(&gsea_args.output);
244+
let gene_list = webgestalt_lib::readers::read_rank_file(gsea_args.rnk.clone())
216245
.unwrap_or_else(|_| {
217-
panic!("File {} not found", gsea_args.rnk.clone().unwrap());
218-
});
219-
let gmt = webgestalt_lib::readers::read_gmt_file(gsea_args.gmt.clone().unwrap())
220-
.unwrap_or_else(|_| {
221-
panic!("File {} not found", gsea_args.gmt.clone().unwrap());
246+
panic!("File {} not found", gsea_args.rnk.clone());
222247
});
248+
let gmt = webgestalt_lib::readers::read_gmt_file(gsea_args.gmt.clone()).unwrap_or_else(
249+
|_| {
250+
panic!("File {} not found", gsea_args.gmt.clone());
251+
},
252+
);
223253
let res =
224254
webgestalt_lib::methods::gsea::gsea(gene_list, gmt, GSEAConfig::default(), None);
255+
let output_file =
256+
File::create(&gsea_args.output).expect("Could not create output file!");
257+
serde_json::to_writer(output_file, &res).expect("Could not create JSON file!");
225258
let mut count = 0;
226259
for i in res {
227260
if i.p < 0.05 && i.fdr < 0.05 {
228261
println!("{}: {}, {}", i.set, i.p, i.fdr);
229262
count += 1;
230263
}
231264
}
232-
println!("Done with GSEA: {}", count);
265+
println!(
266+
"Done with GSEA and found {} significant analyte sets",
267+
count
268+
);
233269
}
234270
Some(Commands::Ora(ora_args)) => {
235-
let style = Style::new().red().bold();
236-
if ora_args.gmt.is_none() || ora_args.interest.is_none() || ora_args.reference.is_none()
237-
{
238-
println!(
239-
"{}: DID NOT PROVIDE PATHS FOR GMT, INTEREST, AND REFERENCE FILE.",
240-
"ERROR".if_supports_color(Stdout, |text| text.style(style))
241-
);
242-
return;
243-
}
271+
check_and_overwrite(&ora_args.output);
244272
let start = Instant::now();
245273
let (gmt, interest, reference) = webgestalt_lib::readers::read_ora_files(
246-
ora_args.gmt.clone().unwrap(),
247-
ora_args.interest.clone().unwrap(),
248-
ora_args.reference.clone().unwrap(),
274+
ora_args.gmt.clone(),
275+
ora_args.interest.clone(),
276+
ora_args.reference.clone(),
249277
);
250278
println!("Reading Took {:?}", start.elapsed());
251279
let start = Instant::now();
@@ -255,6 +283,9 @@ fn main() {
255283
gmt,
256284
ORAConfig::default(),
257285
);
286+
let output_file =
287+
File::create(&ora_args.output).expect("Could not create output file!");
288+
serde_json::to_writer(output_file, &res).expect("Could not create JSON file!");
258289
println!("Analysis Took {:?}", start.elapsed());
259290
let mut count = 0;
260291
for row in res.iter() {
@@ -263,42 +294,33 @@ fn main() {
263294
}
264295
}
265296
println!(
266-
"Found {} significant pathways out of {} pathways",
297+
"Found {} significant analyte sets out of {} sets",
267298
count,
268299
res.len()
269300
);
270301
}
271-
Some(Commands::Test) => will_err(1).unwrap_or_else(|x| println!("{}", x)),
272302
Some(Commands::Nta(nta_args)) => {
273-
let style = Style::new().fg_rgb::<255, 179, 71>().bold();
303+
check_and_overwrite(&nta_args.output);
274304
let network = webgestalt_lib::readers::read_edge_list(nta_args.network.clone());
275305
let start = Instant::now();
276-
if nta_args.method.is_none() {
277-
println!(
278-
"{}: DID NOT PROVIDE A METHOD FOR NTA. USING DEFAULT EXPAND METHOD.",
279-
"WARNING".if_supports_color(Stdout, |text| text.style(style))
280-
);
281-
};
282306
let nta_method = match nta_args.method {
283-
Some(NTAMethodClap::Prioritize) => webgestalt_lib::methods::nta::NTAMethod::Prioritize(
284-
nta_args.neighborhood_size,
285-
),
286-
Some(NTAMethodClap::Expand) => webgestalt_lib::methods::nta::NTAMethod::Expand(
287-
nta_args.neighborhood_size,
288-
),
289-
None => webgestalt_lib::methods::nta::NTAMethod::Expand(nta_args.neighborhood_size),
307+
NTAMethodClap::Prioritize => {
308+
webgestalt_lib::methods::nta::NTAMethod::Prioritize(nta_args.neighborhood_size)
309+
}
310+
NTAMethodClap::Expand => {
311+
webgestalt_lib::methods::nta::NTAMethod::Expand(nta_args.neighborhood_size)
312+
}
290313
};
291314
let config: NTAConfig = NTAConfig {
292315
edge_list: network,
293316
seeds: webgestalt_lib::readers::read_seeds(nta_args.seeds.clone()),
294317
reset_probability: nta_args.reset_probability,
295318
tolerance: nta_args.tolerance,
296319
method: Some(nta_method),
297-
298320
};
299321
let res = webgestalt_lib::methods::nta::get_nta(config);
300322
println!("Analysis Took {:?}", start.elapsed());
301-
webgestalt_lib::writers::save_nta(nta_args.out.clone(), res).unwrap();
323+
webgestalt_lib::writers::save_nta(nta_args.output.clone(), res).unwrap();
302324
}
303325
Some(Commands::Combine(args)) => match &args.combine_type {
304326
Some(CombineType::Gmt(gmt_args)) => {
@@ -374,53 +396,3 @@ fn main() {
374396
}
375397
}
376398
}
377-
378-
fn benchmark() {
379-
let mut bin_durations: Vec<f64> = Vec::new();
380-
for _i in 0..1000 {
381-
let start = Instant::now();
382-
let mut r = BufReader::new(File::open("test.gmt.wga").unwrap());
383-
let _x: Vec<webgestalt_lib::readers::utils::Item> = deserialize_from(&mut r).unwrap();
384-
let duration = start.elapsed();
385-
bin_durations.push(duration.as_secs_f64())
386-
}
387-
let mut gmt_durations: Vec<f64> = Vec::new();
388-
for _i in 0..1000 {
389-
let start = Instant::now();
390-
let _x = webgestalt_lib::readers::read_gmt_file("webgestalt_lib/data/ktest.gmt".to_owned())
391-
.unwrap();
392-
let duration = start.elapsed();
393-
gmt_durations.push(duration.as_secs_f64())
394-
}
395-
let gmt_avg: f64 = gmt_durations.iter().sum::<f64>() / gmt_durations.len() as f64;
396-
let bin_avg: f64 = bin_durations.iter().sum::<f64>() / bin_durations.len() as f64;
397-
let improvement: f64 = 100.0 * (gmt_avg - bin_avg) / gmt_avg;
398-
println!(
399-
" GMT time: {}\tGMT.WGA time: {}\n Improvement: {:.1}%",
400-
gmt_avg, bin_avg, improvement
401-
);
402-
let mut whole_file: Vec<String> = Vec::new();
403-
whole_file.push("type\ttime".to_string());
404-
for line in bin_durations {
405-
whole_file.push(format!("bin\t{:?}", line));
406-
}
407-
for line in gmt_durations {
408-
whole_file.push(format!("gmt\t{:?}", line));
409-
}
410-
let mut ftsv = File::create("format_benchmarks.tsv").unwrap();
411-
writeln!(ftsv, "{}", whole_file.join("\n")).unwrap();
412-
}
413-
414-
fn will_err(x: i32) -> Result<(), WebGestaltError> {
415-
if x == 0 {
416-
Ok(())
417-
} else {
418-
Err(WebGestaltError::MalformedFile(MalformedError {
419-
path: String::from("ExamplePath.txt"),
420-
kind: webgestalt_lib::MalformedErrorType::WrongFormat {
421-
found: String::from("GMT"),
422-
expected: String::from("rank"),
423-
},
424-
}))
425-
}
426-
}

webgestalt_lib/src/readers.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,10 @@ pub fn read_intersection_list(path: String, ref_list: &AHashSet<String>) -> AHas
132132
}
133133

134134
/// Read edge list from specified path. Separated by whitespace with no support for weights
135-
///
135+
///
136136
/// # Parameters
137137
/// path - A [`String`] of the path of the edge list to read.
138-
///
138+
///
139139
/// # Returns
140140
/// A [`Vec<Vec<String>>`] containing the edge list
141141
pub fn read_edge_list(path: String) -> Vec<Vec<String>> {

0 commit comments

Comments
 (0)