Skip to content

Commit 4fbccff

Browse files
committed
Add missing ygs_sort.rs module for YGS (SGD-groom-sort) pipeline
1 parent 88de776 commit 4fbccff

File tree

1 file changed

+269
-0
lines changed

1 file changed

+269
-0
lines changed

src/ygs_sort.rs

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
/// Exact reimplementation of `odgi sort -p Ygs`
2+
///
3+
/// This module implements the three-stage sorting pipeline:
4+
/// 1. Y - Path-guided stochastic gradient descent (PG-SGD)
5+
/// 2. g - Grooming to remove spurious inverting links
6+
/// 3. s - Topological sort using head nodes
7+
///
8+
/// Based on ODGI's implementation in src/subcommand/sort_main.cpp
9+
10+
use crate::bidirected_ops::BidirectedGraph;
11+
use crate::path_sgd::{PathSGDParams, path_sgd_sort};
12+
13+
/// Parameters for the Ygs sorting pipeline
14+
/// These match ODGI's defaults for the Ygs pipeline
15+
#[derive(Clone)]
16+
pub struct YgsParams {
17+
/// Path SGD parameters
18+
pub path_sgd: PathSGDParams,
19+
/// Whether to print progress information
20+
pub verbose: bool,
21+
}
22+
23+
impl Default for YgsParams {
24+
fn default() -> Self {
25+
YgsParams {
26+
path_sgd: PathSGDParams {
27+
iter_max: 30, // ODGI default for production use
28+
iter_with_max_learning_rate: 0,
29+
min_term_updates: 0, // Will be calculated based on paths
30+
delta: 0.0,
31+
eps: 0.01,
32+
eta_max: 0.0, // Will be calculated from path lengths
33+
theta: 0.99,
34+
space: 0, // Will be calculated from longest path
35+
space_max: 100,
36+
space_quantization_step: 100,
37+
cooling_start: 0.5,
38+
nthreads: 1,
39+
progress: false,
40+
},
41+
verbose: false,
42+
}
43+
}
44+
}
45+
46+
impl YgsParams {
47+
/// Create parameters with calculated defaults based on graph structure
48+
/// This matches how ODGI calculates the parameters in sort_main.cpp
49+
pub fn from_graph(graph: &BidirectedGraph, verbose: bool, nthreads: usize) -> Self {
50+
let mut params = Self::default();
51+
params.verbose = verbose;
52+
params.path_sgd.nthreads = nthreads;
53+
params.path_sgd.progress = verbose;
54+
55+
// Calculate parameters based on graph structure
56+
// Build a temporary path index to get statistics
57+
let path_index = crate::path_sgd::PathIndex::from_graph(graph);
58+
59+
// Calculate sum of path step counts
60+
let mut sum_path_step_count = 0u64;
61+
let mut max_path_step_count = 0usize;
62+
let mut max_path_length = 0usize;
63+
64+
for i in 0..path_index.num_paths() {
65+
let step_count = path_index.get_path_step_count(i);
66+
sum_path_step_count += step_count as u64;
67+
max_path_step_count = max_path_step_count.max(step_count);
68+
max_path_length = max_path_length.max(path_index.get_path_length(i));
69+
}
70+
71+
// Set min_term_updates (ODGI default: 1.0 * sum_path_step_count)
72+
params.path_sgd.min_term_updates = sum_path_step_count;
73+
74+
// Set eta_max (ODGI default: max_path_step_count^2)
75+
params.path_sgd.eta_max = (max_path_step_count * max_path_step_count) as f64;
76+
77+
// Set space (ODGI default: max path length)
78+
params.path_sgd.space = max_path_length as u64;
79+
80+
if verbose {
81+
eprintln!("[ygs_sort] Calculated parameters:");
82+
eprintln!(" sum_path_step_count: {}", sum_path_step_count);
83+
eprintln!(" max_path_step_count: {}", max_path_step_count);
84+
eprintln!(" max_path_length: {}", max_path_length);
85+
eprintln!(" min_term_updates: {}", params.path_sgd.min_term_updates);
86+
eprintln!(" eta_max: {}", params.path_sgd.eta_max);
87+
eprintln!(" space: {}", params.path_sgd.space);
88+
}
89+
90+
params
91+
}
92+
}
93+
94+
/// Apply the Ygs sorting pipeline to a graph
95+
/// This exactly replicates `odgi sort -p Ygs`
96+
pub fn ygs_sort(graph: &mut BidirectedGraph, params: &YgsParams) {
97+
if params.verbose {
98+
eprintln!("[ygs_sort] Starting Ygs pipeline (Y=SGD, g=groom, s=topological_sort)");
99+
eprintln!("[ygs_sort] Initial graph: {} nodes, {} edges",
100+
graph.nodes.len(), graph.edges.len());
101+
}
102+
103+
// Step 1: Y - Path-guided SGD sort
104+
if params.verbose {
105+
eprintln!("[ygs_sort] === Step 1/3: Path-guided SGD (Y) ===");
106+
}
107+
108+
let sgd_ordering = path_sgd_sort(graph, params.path_sgd.clone());
109+
graph.apply_ordering(sgd_ordering, params.verbose);
110+
111+
if params.verbose {
112+
eprintln!("[ygs_sort] After SGD: {} nodes", graph.nodes.len());
113+
}
114+
115+
// Step 2: g - Groom the graph
116+
if params.verbose {
117+
eprintln!("[ygs_sort] === Step 2/3: Grooming (g) ===");
118+
}
119+
120+
let groomed_order = graph.groom(true, params.verbose); // Use BFS like ODGI
121+
graph.apply_grooming_with_reorder(groomed_order, false, params.verbose);
122+
123+
if params.verbose {
124+
eprintln!("[ygs_sort] After grooming: {} nodes", graph.nodes.len());
125+
}
126+
127+
// Step 3: s - Topological sort (heads only)
128+
if params.verbose {
129+
eprintln!("[ygs_sort] === Step 3/3: Topological sort (s) ===");
130+
}
131+
132+
// use_heads=true, use_tails=false matches ODGI's 's' command
133+
let topo_order = graph.exact_odgi_topological_order(true, false, params.verbose);
134+
graph.apply_ordering(topo_order, params.verbose);
135+
136+
if params.verbose {
137+
eprintln!("[ygs_sort] After topological sort: {} nodes", graph.nodes.len());
138+
eprintln!("[ygs_sort] === Ygs pipeline complete ===");
139+
}
140+
}
141+
142+
/// Apply just the topological sort step (the 's' part)
143+
/// This is useful for testing or for applying just the final sort
144+
pub fn topological_sort_only(graph: &mut BidirectedGraph, verbose: bool) {
145+
if verbose {
146+
eprintln!("[topological_sort] Starting topological sort (heads only)");
147+
}
148+
149+
let order = graph.exact_odgi_topological_order(true, false, verbose);
150+
graph.apply_ordering(order, verbose);
151+
152+
if verbose {
153+
eprintln!("[topological_sort] Complete");
154+
}
155+
}
156+
157+
/// Apply just the grooming step (the 'g' part)
158+
pub fn groom_only(graph: &mut BidirectedGraph, verbose: bool) {
159+
if verbose {
160+
eprintln!("[groom] Starting grooming");
161+
}
162+
163+
let groomed_order = graph.groom(true, verbose);
164+
graph.apply_grooming_with_reorder(groomed_order, false, verbose);
165+
166+
if verbose {
167+
eprintln!("[groom] Complete");
168+
}
169+
}
170+
171+
/// Apply the SGD step (the 'Y' part)
172+
pub fn sgd_sort_only(graph: &mut BidirectedGraph, params: PathSGDParams, verbose: bool) {
173+
if verbose {
174+
eprintln!("[path_sgd] Starting path-guided SGD");
175+
}
176+
177+
let ordering = path_sgd_sort(graph, params);
178+
graph.apply_ordering(ordering, verbose);
179+
180+
if verbose {
181+
eprintln!("[path_sgd] Complete");
182+
}
183+
}
184+
185+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bidirected_graph::{BiPath, Handle};

    /// Builds a small linear graph (1 -> 2 -> 3) with a single path that
    /// walks all three nodes in forward orientation.
    fn create_test_graph() -> BidirectedGraph {
        let mut graph = BidirectedGraph::new();

        graph.add_node(1, b"AAAA".to_vec());
        graph.add_node(2, b"CCCC".to_vec());
        graph.add_node(3, b"GGGG".to_vec());

        graph.add_edge(Handle::forward(1), Handle::forward(2));
        graph.add_edge(Handle::forward(2), Handle::forward(3));

        // One path covering every node, in order.
        let mut path = BiPath::new(String::from("test_path"));
        for node_id in [1, 2, 3] {
            path.add_step(Handle::forward(node_id));
        }
        graph.paths.push(path);

        graph
    }

    #[test]
    fn test_ygs_params_default() {
        let defaults = YgsParams::default();

        assert_eq!(defaults.path_sgd.iter_max, 30);
        assert_eq!(defaults.path_sgd.theta, 0.99);
        assert_eq!(defaults.path_sgd.eps, 0.01);
    }

    #[test]
    fn test_ygs_params_from_graph() {
        let graph = create_test_graph();
        let derived = YgsParams::from_graph(&graph, false, 1);

        // The graph-dependent values must have been filled in.
        assert!(derived.path_sgd.min_term_updates > 0);
        assert!(derived.path_sgd.eta_max > 0.0);
        assert!(derived.path_sgd.space > 0);
    }

    #[test]
    fn test_ygs_sort_runs() {
        let mut graph = create_test_graph();
        let params = YgsParams::from_graph(&graph, false, 1);

        // The full pipeline must complete without panicking...
        ygs_sort(&mut graph, &params);

        // ...and must leave the graph with its nodes and paths intact.
        assert_eq!(graph.nodes.len(), 3);
        assert!(!graph.paths.is_empty());
    }

    #[test]
    fn test_individual_steps() {
        let base = create_test_graph();

        // SGD stage on its own.
        let mut sgd_graph = base.clone();
        let params = YgsParams::from_graph(&sgd_graph, false, 1);
        sgd_sort_only(&mut sgd_graph, params.path_sgd, false);
        assert_eq!(sgd_graph.nodes.len(), 3);

        // Grooming stage on its own.
        let mut groomed_graph = base.clone();
        groom_only(&mut groomed_graph, false);
        assert_eq!(groomed_graph.nodes.len(), 3);

        // Topological sort stage on its own.
        let mut topo_graph = base.clone();
        topological_sort_only(&mut topo_graph, false);
        assert_eq!(topo_graph.nodes.len(), 3);
    }
}

0 commit comments

Comments
 (0)