diff --git a/src/graph_operators.rs b/src/graph_operators.rs index 971851af..d291a1c0 100644 --- a/src/graph_operators.rs +++ b/src/graph_operators.rs @@ -1,22 +1,25 @@ use crate::models::operations::OperationFile; use crate::models::{ block_group::BlockGroup, - block_group_edge::BlockGroupEdge, + block_group_edge::{BlockGroupEdge, BlockGroupEdgeData}, + edge::{Edge, EdgeData}, file_types::FileTypes, - node::{PATH_END_NODE_ID, PATH_START_NODE_ID}, + node::Node, operations::{Operation, OperationInfo}, path::Path, path_edge::PathEdge, sample::Sample, + strand::Strand, }; use crate::operation_management::{end_operation, start_operation, OperationError}; use core::ops::Range; +use itertools::Itertools; use rusqlite::Connection; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use thiserror::Error; #[derive(Debug, Error, PartialEq)] -pub enum DeriveGraphError { +pub enum GraphOperationError { #[error("Operation Error: {0}")] OperationError(#[from] OperationError), #[error("Invalid coordinate(s): {0}")] @@ -33,13 +36,13 @@ pub fn get_path( sample_name: Option<&str>, region_name: &str, backbone: Option<&str>, -) -> Result { +) -> Result { let block_group_id = get_block_group_id(conn, collection_name, sample_name, region_name)?; if let Some(backbone) = backbone { let path = BlockGroup::get_path_by_name(conn, block_group_id, backbone); if path.is_none() { - return Err(DeriveGraphError::PathNotFound(format!( + return Err(GraphOperationError::PathNotFound(format!( "No path found with name {}", backbone ))); @@ -60,7 +63,7 @@ pub fn derive_chunks( region_name: &str, backbone: Option<&str>, chunk_ranges: Vec>, -) -> Result { +) -> Result { let mut session = start_operation(conn); let _new_sample = Sample::get_or_create(conn, new_sample_name); @@ -76,24 +79,30 @@ pub fn derive_chunks( let current_path_length = current_path.length(conn); let current_intervaltree = current_path.intervaltree(conn); - let current_edges = PathEdge::edges_for_path(conn, 
current_path.id); + let current_path_edges = PathEdge::edges_for_path(conn, current_path.id); + let chunk_ranges_length = chunk_ranges.len(); for (i, chunk_range) in chunk_ranges.clone().into_iter().enumerate() { - let new_block_group_name = format!("{}.{}", region_name, i + 1); - let new_block_group = BlockGroup::create( + let child_block_group_name = if chunk_ranges_length > 1 { + format!("{}.{}", region_name, i + 1) + } else { + region_name.to_string() + }; + + let child_block_group = BlockGroup::create( conn, collection_name, Some(new_sample_name), - new_block_group_name.as_str(), + child_block_group_name.as_str(), ); - let new_block_group_id = new_block_group.id; + let child_block_group_id = child_block_group.id; let start_coordinate = chunk_range.start; let end_coordinate = chunk_range.end; if (start_coordinate < 0 || start_coordinate > current_path_length) || (end_coordinate < 0 || end_coordinate > current_path_length) { - return Err(DeriveGraphError::InvalidCoordinate(format!( + return Err(GraphOperationError::InvalidCoordinate(format!( "Start and/or end coordinates ({}, {}) are out of range for the current path.", start_coordinate, end_coordinate ))); @@ -113,53 +122,87 @@ pub fn derive_chunks( let end_block = blocks[blocks.len() - 1]; let end_node_coordinate = end_coordinate - end_block.start + end_block.sequence_start; - let child_block_group_id = BlockGroup::derive_subgraph( + let new_node_ids_by_old = BlockGroup::derive_subgraph( conn, + collection_name, + new_sample_name, parent_block_group_id, &start_block, &end_block, start_node_coordinate, end_node_coordinate, - new_block_group_id, + child_block_group_id, ); let child_block_group_edges = BlockGroupEdge::edges_for_block_group(conn, child_block_group_id); - let new_edge_id_set = child_block_group_edges + + let child_edge_ids_by_key = child_block_group_edges .iter() - .map(|x| x.edge.id) - .collect::>(); + .map(|augmented_edge| { + let edge = &augmented_edge.edge; + ( + ( + edge.source_node_id, + 
edge.source_coordinate, + edge.source_strand, + edge.target_node_id, + edge.target_coordinate, + edge.target_strand, + ), + edge.id, + ) + }) + .collect::>(); + // The block group method to derive a subgraph creates copies of the nodes from the parent + // block group, to make it easier to then stitch them together later. So we need to use the + // map returned by derive_subgraph to find the edges in the child graph that correspond to + // path edges in the parent graph, and create a new path from the child edges. let mut new_path_edge_ids = vec![]; - let start_edge = ¤t_edges[0]; - if !new_edge_id_set.contains(&start_edge.id) { - let new_start_edge = child_block_group_edges - .iter() - .find(|e| { - e.edge.source_node_id == PATH_START_NODE_ID - && e.edge.target_node_id == start_block.node_id - && e.edge.target_coordinate == start_node_coordinate - }) - .unwrap(); - new_path_edge_ids.push(new_start_edge.edge.id); - } - for current_edge in ¤t_edges { - if new_edge_id_set.contains(¤t_edge.id) { - new_path_edge_ids.push(current_edge.id); + let new_start_target_node_id = new_node_ids_by_old.get(&start_block.node_id).unwrap(); + let new_start_edge = child_block_group_edges + .iter() + .find(|e| { + Node::is_start_node(e.edge.source_node_id) + && e.edge.target_node_id == *new_start_target_node_id + && e.edge.target_coordinate == start_node_coordinate + }) + .unwrap(); + new_path_edge_ids.push(new_start_edge.edge.id); + + for edge in ¤t_path_edges { + let new_source_node_id = new_node_ids_by_old.get(&edge.source_node_id); + let new_target_node_id = new_node_ids_by_old.get(&edge.target_node_id); + if let Some(new_source_node_id) = new_source_node_id { + if let Some(new_target_node_id) = new_target_node_id { + let key = &( + *new_source_node_id, + edge.source_coordinate, + edge.source_strand, + *new_target_node_id, + edge.target_coordinate, + edge.target_strand, + ); + let child_edge_id = child_edge_ids_by_key.get(key); + if let Some(child_edge_id) = child_edge_id { + 
new_path_edge_ids.push(*child_edge_id); + } + } } } - let end_edge = ¤t_edges[current_edges.len() - 1]; - if !new_edge_id_set.contains(&end_edge.id) { - let new_end_edge = child_block_group_edges - .iter() - .find(|e| { - e.edge.target_node_id == PATH_END_NODE_ID - && e.edge.source_node_id == end_block.node_id - && e.edge.source_coordinate == end_node_coordinate - }) - .unwrap(); - new_path_edge_ids.push(new_end_edge.edge.id); - } + + let new_end_source_node_id = new_node_ids_by_old.get(&end_block.node_id).unwrap(); + let new_end_edge = child_block_group_edges + .iter() + .find(|e| { + Node::is_end_node(e.edge.target_node_id) + && e.edge.source_node_id == *new_end_source_node_id + && e.edge.source_coordinate == end_node_coordinate + }) + .unwrap(); + new_path_edge_ids.push(new_end_edge.edge.id); + Path::create( conn, ¤t_path.name, @@ -187,7 +230,7 @@ pub fn derive_chunks( &summary_str, None, ) - .map_err(DeriveGraphError::OperationError); + .map_err(GraphOperationError::OperationError); println!("Derived chunks successfully."); @@ -199,7 +242,7 @@ fn get_block_group_id( collection_name: &str, parent_sample_name: Option<&str>, region_name: &str, -) -> Result { +) -> Result { let block_groups = Sample::get_block_groups(conn, collection_name, parent_sample_name); for block_group in &block_groups { @@ -208,12 +251,268 @@ fn get_block_group_id( } } - Err(DeriveGraphError::RegionNotFound(format!( + Err(GraphOperationError::RegionNotFound(format!( "No region found with name: {}", region_name ))) } +pub fn make_stitch( + conn: &Connection, + operation_conn: &Connection, + collection_name: &str, + parent_sample_name: Option<&str>, + new_sample_name: &str, + region_names: &Vec<&str>, + new_region_name: &str, +) -> Result { + let mut session = start_operation(conn); + + let _new_sample = Sample::get_or_create(conn, new_sample_name); + let block_groups = Sample::get_block_groups(conn, collection_name, parent_sample_name); + + let mut block_groups_by_name = HashMap::new(); + 
for block_group in &block_groups { + let block_group_name = block_group.name.as_str(); + if region_names.contains(&block_group_name) { + block_groups_by_name.insert(block_group_name, block_group); + } + } + + let mut source_node_coordinates: Vec<(i64, i64, Strand)> = vec![]; + let mut edges_to_reuse = vec![]; + let mut edges_to_create = vec![]; + let mut concatenated_path_edges = vec![]; + + // Part 1 + // * Collect all the existing edges from the regions to be stitched together + // * Except edges to/from terminal nodes + // * Also build up a list of edges to create to stitch end nodes of one region to start nodes of + // the next region + for region_name in region_names { + if let Some(block_group) = block_groups_by_name.get(region_name) { + let edges = BlockGroupEdge::edges_for_block_group(conn, block_group.id); + + let nonterminal_edges = edges + .iter() + .filter(|edge| !edge.edge.is_start_edge() && !edge.edge.is_end_edge()) + .cloned(); + edges_to_reuse.extend(nonterminal_edges); + + let start_edges = edges + .iter() + .filter(|edge| edge.edge.is_start_edge()) + .collect::>(); + // Add all edges between the end nodes of the previous region and the start nodes of + // this region + for source_node_coordinate in &source_node_coordinates { + for start_edge in &start_edges { + edges_to_create.push(EdgeData { + source_node_id: source_node_coordinate.0, + source_coordinate: source_node_coordinate.1, + source_strand: source_node_coordinate.2, + target_node_id: start_edge.edge.target_node_id, + target_coordinate: start_edge.edge.target_coordinate, + target_strand: start_edge.edge.target_strand, + }); + } + } + + let end_edges = edges.iter().filter(|edge| edge.edge.is_end_edge()); + source_node_coordinates = end_edges + .map(|edge| { + ( + edge.edge.source_node_id, + edge.edge.source_coordinate, + edge.edge.source_strand, + ) + }) + .collect(); + + let current_path = BlockGroup::get_current_path(conn, block_group.id); + 
concatenated_path_edges.extend(PathEdge::edges_for_path(conn, current_path.id)); + } else { + return Err(GraphOperationError::RegionNotFound(format!( + "No region found with name: {}", + region_name + ))); + } + } + + // Part 2: + // * Add in existing edges from the virtual start node to the start nodes of the first region + // * Add in existing edges from the end nodes of the last region to the virtual end node + let start_region = block_groups_by_name.get(region_names[0]).unwrap(); + let start_region_edges = BlockGroupEdge::edges_for_block_group(conn, start_region.id); + for start_region_edge in &start_region_edges { + if start_region_edge.edge.is_start_edge() { + edges_to_reuse.push(start_region_edge.clone()); + } + } + + let end_region = block_groups_by_name + .get(region_names[region_names.len() - 1]) + .unwrap(); + let end_region_edges = BlockGroupEdge::edges_for_block_group(conn, end_region.id); + for end_region_edge in &end_region_edges { + if end_region_edge.edge.is_end_edge() { + edges_to_reuse.push(end_region_edge.clone()); + } + } + + // Part 3: Set up the block group, set up bg edges for the edges to reuse, create the necessary + // new edges. 
+ // We'll do a bulk create for the bg edges later in one big call, once we have more information + // for the new edges (which will also get bg edges created then) + let child_block_group = BlockGroup::create( + conn, + collection_name, + Some(new_sample_name), + new_region_name, + ); + let child_block_group_id = child_block_group.id; + + let mut bg_edges = edges_to_reuse + .iter() + .map(|edge| BlockGroupEdgeData { + block_group_id: child_block_group_id, + edge_id: edge.edge.id, + chromosome_index: edge.chromosome_index, + phased: edge.phased, + }) + .collect::>(); + + let created_edge_ids = Edge::bulk_create(conn, &edges_to_create); + let created_edges = Edge::bulk_load(conn, &created_edge_ids); + + // Part 4: Set up a new path + let created_edges_by_node_info = created_edges + .iter() + .map(|edge| { + ( + ( + edge.source_node_id, + edge.source_coordinate, + edge.source_strand, + edge.target_node_id, + edge.target_coordinate, + edge.target_strand, + ), + edge.clone(), + ) + }) + .collect::>(); + + let mut stitch_path_edge_ids = vec![]; + for (path_edge1, path_edge2) in concatenated_path_edges.iter().tuple_windows() { + if Node::is_end_node(path_edge1.target_node_id) + && Node::is_start_node(path_edge2.source_node_id) + { + stitch_path_edge_ids.push( + created_edges_by_node_info[&( + path_edge1.source_node_id, + path_edge1.source_coordinate, + path_edge1.source_strand, + path_edge2.target_node_id, + path_edge2.target_coordinate, + path_edge2.target_strand, + )] + .id, + ); + } + } + + let mut new_path_edge_ids = vec![concatenated_path_edges[0].id]; + let mut stitch_count = 0; + for path_edge in &concatenated_path_edges { + if path_edge.is_end_edge() { + if stitch_count < stitch_path_edge_ids.len() { + new_path_edge_ids.push(stitch_path_edge_ids[stitch_count]); + stitch_count += 1; + } + } else if !path_edge.is_start_edge() { + new_path_edge_ids.push(path_edge.id); + } + } + + new_path_edge_ids.push(concatenated_path_edges[concatenated_path_edges.len() - 1].id); 
+ + // Part 5: Create bg edges for the new edges + let mut chromosome_index_counter = edges_to_reuse + .iter() + .max_by(|x, y| x.chromosome_index.cmp(&y.chromosome_index)) + .unwrap() + .chromosome_index + + 1; + + let path_edge_id_set = new_path_edge_ids.iter().collect::>(); + for created_edge in created_edges { + if path_edge_id_set.contains(&created_edge.id) { + bg_edges.push(BlockGroupEdgeData { + block_group_id: child_block_group_id, + edge_id: created_edge.id, + chromosome_index: 0, + phased: 0, + }); + } else { + bg_edges.push(BlockGroupEdgeData { + block_group_id: child_block_group_id, + edge_id: created_edge.id, + chromosome_index: chromosome_index_counter, + phased: 0, + }); + chromosome_index_counter += 1; + } + } + + BlockGroupEdge::bulk_create(conn, &bg_edges); + + Path::create( + conn, + new_region_name, + child_block_group_id, + &new_path_edge_ids, + ); + + let summary_str = format!( + " {}: stitched {} chunks into new graph", + new_sample_name, + region_names.len() + ); + + let op = end_operation( + conn, + operation_conn, + &mut session, + &OperationInfo { + files: vec![OperationFile { + file_path: "".to_string(), + file_type: FileTypes::None, + }], + description: "make stitch".to_string(), + }, + &summary_str, + None, + ); + + match op { + Ok(op) => Ok(op), + Err(e) => match e { + OperationError::NoChanges => { + println!("Stitched graph already exists, nothing updated."); + Ok(Operation { + hash: "".to_string(), + db_uuid: "".to_string(), + parent_hash: None, + branch_id: 0, + change_type: "".to_string(), + }) + } + _ => Err(GraphOperationError::OperationError(e)), + }, + } +} + #[cfg(test)] mod tests { use super::*; @@ -252,7 +551,11 @@ mod tests { .sequence_type("DNA") .sequence("AAAAAAAA") .save(conn); - let insert_node_id = Node::create(conn, insert_sequence.hash.as_str(), None); + let insert_node_id = Node::create( + conn, + insert_sequence.hash.as_str(), + format!("test-insert-a.{}", insert_sequence.hash), + ); let edge_into_insert = 
Edge::create( conn, insert_start_node_id, @@ -313,7 +616,7 @@ mod tests { .unwrap(); let block_groups = Sample::get_block_groups(conn, "test", Some("test")); - let block_group2 = block_groups.iter().find(|x| x.name == "chr1.1").unwrap(); + let block_group2 = block_groups.iter().find(|x| x.name == "chr1").unwrap(); let all_sequences2 = BlockGroup::get_all_sequences(conn, block_group2.id, false); assert_eq!( @@ -439,4 +742,185 @@ mod tests { let path3 = BlockGroup::get_current_path(conn, block_group3.id); assert_eq!(path3.sequence(conn), "ATCGATCAAGGAACACA"); } + + #[test] + fn derive_chunks_two_inserts_then_stitch() { + let mut fasta_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + fasta_path.push("fixtures/simple.fa"); + let mut fasta_update_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + fasta_update_path.push("fixtures/aa.fa"); + + setup_gen_dir(); + let conn = &get_connection(None); + let db_uuid = metadata::get_db_uuid(conn); + let op_conn = &get_operation_connection(None); + setup_db(op_conn, &db_uuid); + + let collection = "test"; + + import_fasta( + &fasta_path.to_str().unwrap().to_string(), + collection, + None, + false, + conn, + op_conn, + ) + .unwrap(); + + let _ = update_with_fasta( + conn, + op_conn, + collection, + None, + "test1", + "m123", + 3, + 5, + fasta_update_path.to_str().unwrap(), + ); + + let _ = update_with_fasta( + conn, + op_conn, + collection, + Some("test1"), + "test2", + "m123", + 15, + 20, + fasta_update_path.to_str().unwrap(), + ); + + let original_block_groups = Sample::get_block_groups(conn, collection, None); + let original_block_group_id = original_block_groups[0].id; + let all_original_sequences = + BlockGroup::get_all_sequences(conn, original_block_group_id, false); + assert_eq!( + all_original_sequences, + HashSet::from_iter(vec!["ATCGATCGATCGATCGATCGGGAACACACAGAGA".to_string(),]) + ); + + let grandchild_block_groups = Sample::get_block_groups(conn, collection, Some("test2")); + let grandchild_block_group_id = 
grandchild_block_groups[0].id; + let all_grandchild_sequences = + BlockGroup::get_all_sequences(conn, grandchild_block_group_id, false); + assert_eq!( + all_grandchild_sequences, + HashSet::from_iter(vec![ + "ATCGATCGATCGATCGATCGGGAACACACAGAGA".to_string(), + "ATCAATCGATCGATCGATCGGGAACACACAGAGA".to_string(), + "ATCGATCGATCGATCAAGGAACACACAGAGA".to_string(), + "ATCAATCGATCGATCAAGGAACACACAGAGA".to_string(), + ]) + ); + + derive_chunks( + conn, + op_conn, + collection, + Some("test2"), + "test3", + "m123", + None, + vec![ + Range { start: 0, end: 1 }, + Range { start: 1, end: 8 }, + Range { start: 8, end: 25 }, + Range { start: 25, end: 31 }, + ], + ) + .unwrap(); + + let block_groups = Sample::get_block_groups(conn, collection, Some("test3")); + let block_group2 = block_groups.iter().find(|x| x.name == "m123.2").unwrap(); + + let all_sequences2 = BlockGroup::get_all_sequences(conn, block_group2.id, false); + assert_eq!( + all_sequences2, + HashSet::from_iter(vec!["TCAATCG".to_string(), "TCGATCG".to_string(),]) + ); + + let path2 = BlockGroup::get_current_path(conn, block_group2.id); + assert_eq!(path2.sequence(conn), "TCAATCG"); + + let block_group3 = block_groups.iter().find(|x| x.name == "m123.3").unwrap(); + let all_sequences3 = BlockGroup::get_all_sequences(conn, block_group3.id, false); + assert_eq!( + all_sequences3, + HashSet::from_iter(vec![ + "ATCGATCAAGGAACACA".to_string(), + "ATCGATCGATCGGGAACACA".to_string(), + ]) + ); + + let path3 = BlockGroup::get_current_path(conn, block_group3.id); + assert_eq!(path3.sequence(conn), "ATCGATCAAGGAACACA"); + + // Stitch the two main chunks back together in same order + make_stitch( + conn, + op_conn, + collection, + Some("test3"), + "test4", + &vec!["m123.2", "m123.3"], + "m123.stitched", + ) + .unwrap(); + + let block_groups = Sample::get_block_groups(conn, collection, Some("test4")); + let block_group4 = block_groups + .iter() + .find(|x| x.name == "m123.stitched") + .unwrap(); + + let all_sequences4 = 
BlockGroup::get_all_sequences(conn, block_group4.id, false); + assert_eq!( + all_sequences4, + HashSet::from_iter(vec![ + "TCAATCGATCGATCAAGGAACACA".to_string(), + "TCAATCGATCGATCGATCGGGAACACA".to_string(), + "TCGATCGATCGATCAAGGAACACA".to_string(), + "TCGATCGATCGATCGATCGGGAACACA".to_string(), + ]) + ); + + let path4 = BlockGroup::get_current_path(conn, block_group4.id); + // path2 + path3 concatenated + assert_eq!(path4.sequence(conn), "TCAATCGATCGATCAAGGAACACA"); + + // Stitch the two main chunks together but in reverse order + make_stitch( + conn, + op_conn, + collection, + Some("test3"), + "test5", + &vec!["m123.3", "m123.2"], + "m123.reverse-stitched", + ) + .unwrap(); + + let block_groups = Sample::get_block_groups(conn, collection, Some("test5")); + let block_group5 = block_groups + .iter() + .find(|x| x.name == "m123.reverse-stitched") + .unwrap(); + + let all_sequences5 = BlockGroup::get_all_sequences(conn, block_group5.id, false); + assert_eq!( + all_sequences5, + HashSet::from_iter(vec![ + "ATCGATCAAGGAACACATCAATCG".to_string(), + "ATCGATCAAGGAACACATCGATCG".to_string(), + "ATCGATCGATCGGGAACACATCAATCG".to_string(), + "ATCGATCGATCGGGAACACATCGATCG".to_string(), + ]) + ); + + let path5 = BlockGroup::get_current_path(conn, block_group5.id); + // path3 + path2 concatenated + assert_eq!(path5.sequence(conn), "ATCGATCAAGGAACACATCAATCG"); + } } diff --git a/src/main.rs b/src/main.rs index c4bc8b02..d99f1856 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,7 @@ use gen::exports::fasta::export_fasta; use gen::exports::genbank::export_genbank; use gen::exports::gfa::export_gfa; use gen::get_connection; -use gen::graph_operators::{derive_chunks, get_path}; +use gen::graph_operators::{derive_chunks, get_path, make_stitch}; use gen::imports::fasta::{import_fasta, FastaError}; use gen::imports::genbank::import_genbank; use gen::imports::gfa::{import_gfa, GFAImportError}; @@ -420,6 +420,28 @@ enum Commands { #[arg(long)] chunk_size: Option, }, + #[command( + 
verbatim_doc_comment, + long_about = "Combine multiple sequence graphs into one. Example: + gen make-stitch --sample parent_sample --new-sample my_child_sample --regions chr1.2,chr1.3 --new-region spliced_chr1" + )] + MakeStitch { + /// The name of the collection to derive the subgraph from + #[arg(short, long)] + name: Option, + /// The name of the parent sample + #[arg(short, long)] + sample: Option, + /// The name of the new sample + #[arg(long)] + new_sample: String, + /// The names of the regions to combine + #[arg(long)] + regions: String, + /// The name of the new region + #[arg(long)] + new_region: String, + }, } fn main() { @@ -1226,5 +1248,37 @@ fn main() { conn.execute("END TRANSACTION", []).unwrap(); operation_conn.execute("END TRANSACTION", []).unwrap(); } + Some(Commands::MakeStitch { + name, + sample, + new_sample, + regions, + new_region, + }) => { + conn.execute("BEGIN TRANSACTION", []).unwrap(); + operation_conn.execute("BEGIN TRANSACTION", []).unwrap(); + let name = &name + .clone() + .unwrap_or_else(|| get_default_collection(&operation_conn)); + let sample_name = sample.clone(); + let new_sample_name = new_sample.clone(); + + let region_names = regions.split(",").collect::>(); + + match make_stitch( + &conn, + &operation_conn, + name, + sample_name.as_deref(), + &new_sample_name, + ®ion_names, + new_region, + ) { + Ok(_) => {} + Err(e) => panic!("Error stitching subgraphs: {e}"), + } + conn.execute("END TRANSACTION", []).unwrap(); + operation_conn.execute("END TRANSACTION", []).unwrap(); + } } } diff --git a/src/models/block_group.rs b/src/models/block_group.rs index 3aada42e..68ae8706 100644 --- a/src/models/block_group.rs +++ b/src/models/block_group.rs @@ -820,15 +820,18 @@ impl BlockGroup { None } + #[allow(clippy::too_many_arguments)] pub fn derive_subgraph( conn: &Connection, + collection_name: &str, + child_sample_name: &str, source_block_group_id: i64, start_block: &NodeIntervalBlock, end_block: &NodeIntervalBlock, start_node_coordinate: 
i64, end_node_coordinate: i64, target_block_group_id: i64, - ) -> i64 { + ) -> HashMap { let current_graph = BlockGroup::get_graph(conn, source_block_group_id); let start_node = current_graph .nodes() @@ -855,6 +858,99 @@ impl BlockGroup { .collect::>(); let source_edges = Edge::bulk_load(conn, &subgraph_edge_ids); + // Instead of reusing the existing edges, we create copies of the nodes and edges. This + // makes it easier to recombine subgraphs later. + // TODO: Annotation support. We do not reuse node ids because nodes control the graph + // topology. If we duplicated nodes via these operations, it would lead to unintentional + // cycles. A node cannot exist in 2 places. This will be non-intuitive for annotation + // propagation and needs to be handled. + let mut old_node_ids = HashSet::new(); + for source_edge in &source_edges { + old_node_ids.insert(source_edge.source_node_id); + old_node_ids.insert(source_edge.target_node_id); + } + + old_node_ids.insert(start_block.node_id); + old_node_ids.insert(end_block.node_id); + + let old_nodes = Node::get_nodes(conn, &old_node_ids.iter().cloned().collect::>()); + let old_nodes_by_id = old_nodes + .iter() + .map(|node| (node.id, node)) + .collect::>(); + + let mut new_node_ids_by_old = HashMap::new(); + let old_start_node = old_nodes_by_id.get(&start_block.node_id).unwrap(); + let new_start_node_hash = format!( + "{}.{}.{}.bg-{}", + collection_name, + child_sample_name, + old_start_node.hash.clone().unwrap(), + target_block_group_id + ); + new_node_ids_by_old.insert( + old_start_node.id, + Node::create(conn, &old_start_node.sequence_hash, new_start_node_hash), + ); + let old_end_node = old_nodes_by_id.get(&end_block.node_id).unwrap(); + let new_end_node_hash = format!( + "{}.{}.{}.bg-{}", + collection_name, + child_sample_name, + old_end_node.hash.clone().unwrap(), + target_block_group_id + ); + new_node_ids_by_old.insert( + old_end_node.id, + Node::create(conn, &old_end_node.sequence_hash, new_end_node_hash), + ); + 
+ let mut new_edges = vec![]; + for source_edge in &source_edges { + let new_source_node_id = *new_node_ids_by_old + .entry(source_edge.source_node_id) + .or_insert_with(|| { + let old_source_node = old_nodes_by_id.get(&source_edge.source_node_id).unwrap(); + let new_source_node_hash = format!( + "{}.{}.{}.bg-{}", + collection_name, + child_sample_name, + old_source_node.hash.clone().unwrap(), + target_block_group_id + ); + Node::create(conn, &old_source_node.sequence_hash, new_source_node_hash) + }); + let new_target_node_id = *new_node_ids_by_old + .entry(source_edge.target_node_id) + .or_insert_with(|| { + let old_target_node = old_nodes_by_id.get(&source_edge.target_node_id).unwrap(); + let new_target_node_hash = format!( + "{}.{}.{}.bg-{}", + collection_name, + child_sample_name, + old_target_node.hash.clone().unwrap(), + target_block_group_id + ); + Node::create(conn, &old_target_node.sequence_hash, new_target_node_hash) + }); + + new_edges.push(EdgeData { + source_node_id: new_source_node_id, + source_coordinate: source_edge.source_coordinate, + source_strand: source_edge.source_strand, + target_node_id: new_target_node_id, + target_coordinate: source_edge.target_coordinate, + target_strand: source_edge.target_strand, + }); + } + + let new_edge_ids = Edge::bulk_create(conn, &new_edges); + let new_edge_ids_by_old = source_edges + .iter() + .zip(new_edge_ids.iter()) + .map(|(source_edge, new_edge_id)| (source_edge.id, *new_edge_id)) + .collect::>(); + let source_block_group_edges = BlockGroupEdge::specific_edges_for_block_group( conn, source_block_group_id, @@ -879,21 +975,23 @@ impl BlockGroup { let block_group_edge = source_block_group_edges_by_edge_id .get(&edge.edge.id) .unwrap(); + let new_edge_id = new_edge_ids_by_old.get(&edge.edge.id).unwrap(); BlockGroupEdgeData { block_group_id: target_block_group_id, - edge_id: edge.edge.id, + edge_id: *new_edge_id, chromosome_index: block_group_edge.chromosome_index, phased: block_group_edge.phased, } }) 
.collect::>(); + let new_start_node_id = new_node_ids_by_old.get(&start_block.node_id).unwrap(); let new_start_edge = Edge::create( conn, PATH_START_NODE_ID, 0, Strand::Forward, - start_block.node_id, + *new_start_node_id, start_node_coordinate, start_block.strand, ); @@ -903,9 +1001,10 @@ impl BlockGroup { chromosome_index: 0, phased: 0, }; + let new_end_node_id = new_node_ids_by_old.get(&end_block.node_id).unwrap(); let new_end_edge = Edge::create( conn, - end_block.node_id, + *new_end_node_id, end_node_coordinate, end_block.strand, PATH_END_NODE_ID, @@ -923,7 +1022,7 @@ impl BlockGroup { all_edges.push(new_end_edge_data); BlockGroupEdge::bulk_create(conn, &all_edges); - target_block_group_id + new_node_ids_by_old } } @@ -2433,7 +2532,11 @@ mod tests { .sequence_type("DNA") .sequence("AAAAAAAA") .save(conn); - let insert_node_id = Node::create(conn, insert_sequence.hash.as_str(), None); + let insert_node_id = Node::create( + conn, + insert_sequence.hash.as_str(), + format!("test-insert-a-node.{}", insert_sequence.hash), + ); let edge_into_insert = Edge::create( conn, insert_start_node_id, @@ -2494,6 +2597,8 @@ mod tests { let block_group2 = BlockGroup::create(conn, "test", None, "chr1.1"); BlockGroup::derive_subgraph( conn, + "test", + "test", block_group1_id, &start_block, &end_block, @@ -2527,7 +2632,11 @@ mod tests { .sequence_type("DNA") .sequence("AAAAAAAA") .save(conn); - let insert_node_id = Node::create(conn, insert_sequence.hash.as_str(), None); + let insert_node_id = Node::create( + conn, + insert_sequence.hash.as_str(), + format!("test-insert-a-node.{}", insert_sequence.hash), + ); let edge_into_insert = Edge::create( conn, insert_start_node_id, @@ -2573,7 +2682,11 @@ mod tests { .sequence_type("DNA") .sequence("TTTTTTTT") .save(conn); - let insert2_node_id = Node::create(conn, insert2_sequence.hash.as_str(), None); + let insert2_node_id = Node::create( + conn, + insert2_sequence.hash.as_str(), + format!("test-insert-t-node.{}", insert2_sequence.hash), 
+ ); let edge_into_insert2 = Edge::create( conn, insert2_start_node_id, @@ -2636,6 +2749,8 @@ mod tests { let block_group2 = BlockGroup::create(conn, "test", None, "chr1.1"); BlockGroup::derive_subgraph( conn, + "test", + "test", block_group1_id, &start_block, &end_block, @@ -2677,7 +2792,11 @@ mod tests { .sequence_type("DNA") .sequence("AAAAAAAA") .save(conn); - let insert_node_id = Node::create(conn, insert_sequence.hash.as_str(), None); + let insert_node_id = Node::create( + conn, + insert_sequence.hash.as_str(), + format!("test-insert-a-node.{}", insert_sequence.hash), + ); let edge_into_insert = Edge::create( conn, insert_start_node_id, @@ -2723,7 +2842,11 @@ mod tests { .sequence_type("DNA") .sequence("TTTTTTTT") .save(conn); - let insert2_node_id = Node::create(conn, insert2_sequence.hash.as_str(), None); + let insert2_node_id = Node::create( + conn, + insert2_sequence.hash.as_str(), + format!("test-insert-t-node.{}", insert2_sequence.hash), + ); let edge_into_insert2 = Edge::create( conn, insert2_start_node_id, @@ -2805,6 +2928,8 @@ mod tests { let block_group2 = BlockGroup::create(conn, "test", None, "chr1.1"); BlockGroup::derive_subgraph( conn, + "test", + "test", block_group1_id, &start_block, &end_block, diff --git a/src/models/block_group_edge.rs b/src/models/block_group_edge.rs index e9cece75..27f5f647 100644 --- a/src/models/block_group_edge.rs +++ b/src/models/block_group_edge.rs @@ -6,7 +6,7 @@ use rusqlite::{Connection, Row}; use std::collections::HashMap; use std::rc::Rc; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd)] pub struct BlockGroupEdge { pub id: i64, pub block_group_id: i64, @@ -15,7 +15,7 @@ pub struct BlockGroupEdge { pub phased: i64, } -#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd)] pub struct BlockGroupEdgeData { pub block_group_id: i64, pub edge_id: i64, @@ -23,14 +23,14 @@ pub struct BlockGroupEdgeData { pub phased: i64, } 
-#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd)] pub struct AugmentedEdge { pub edge: Edge, pub chromosome_index: i64, pub phased: i64, } -#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd)] pub struct AugmentedEdgeData { pub edge_data: EdgeData, pub chromosome_index: i64, diff --git a/src/models/edge.rs b/src/models/edge.rs index e36c8d01..224990a9 100644 --- a/src/models/edge.rs +++ b/src/models/edge.rs @@ -14,7 +14,7 @@ use crate::models::sequence::{cached_sequence, Sequence}; use crate::models::strand::Strand; use crate::models::traits::*; -#[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)] +#[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize, Ord, PartialOrd)] pub struct Edge { pub id: i64, pub source_node_id: i64, @@ -25,7 +25,7 @@ pub struct Edge { pub target_strand: Strand, } -#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd)] pub struct EdgeData { pub source_node_id: i64, pub source_coordinate: i64, @@ -493,6 +493,14 @@ impl Edge { } boundary_edges } + + pub fn is_start_edge(&self) -> bool { + self.source_node_id == PATH_START_NODE_ID + } + + pub fn is_end_edge(&self) -> bool { + self.target_node_id == PATH_END_NODE_ID + } } #[cfg(test)] diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 6a5073d3..39a7e455 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -74,22 +74,38 @@ pub fn setup_block_group(conn: &Connection) -> (i64, Path) { .sequence_type("DNA") .sequence("AAAAAAAAAA") .save(conn); - let a_node_id = Node::create(conn, a_seq.hash.as_str(), None); + let a_node_id = Node::create( + conn, + a_seq.hash.as_str(), + format!("test-a-node.{}", a_seq.hash), + ); let t_seq = Sequence::new() .sequence_type("DNA") .sequence("TTTTTTTTTT") .save(conn); - let t_node_id = Node::create(conn, t_seq.hash.as_str(), None); + let t_node_id = 
Node::create( + conn, + t_seq.hash.as_str(), + format!("test-t-node.{}", a_seq.hash), + ); let c_seq = Sequence::new() .sequence_type("DNA") .sequence("CCCCCCCCCC") .save(conn); - let c_node_id = Node::create(conn, c_seq.hash.as_str(), None); + let c_node_id = Node::create( + conn, + c_seq.hash.as_str(), + format!("test-c-node.{}", a_seq.hash), + ); let g_seq = Sequence::new() .sequence_type("DNA") .sequence("GGGGGGGGGG") .save(conn); - let g_node_id = Node::create(conn, g_seq.hash.as_str(), None); + let g_node_id = Node::create( + conn, + g_seq.hash.as_str(), + format!("test-g-node.{}", a_seq.hash), + ); let _collection = Collection::create(conn, "test"); let block_group = BlockGroup::create(conn, "test", None, "chr1"); let edge0 = Edge::create(