Status: Proposal
Version: 0.1.0-draft
Target Audience: Rust library authors, ML engineers, systems programmers
Python's scikit-learn is arguably the most successful machine learning library ever built. Its success comes not from raw performance, but from a carefully designed API surface: consistent interfaces, composable pipelines, and a principled separation between data, transformers, and estimators. The Rust ecosystem has fragments — linfa covers some classical ML, ndarray handles N-dimensional arrays, polars handles dataframes — but nothing unifies them under a coherent, ergonomic, production-ready framework.
ferrolearn (working name) should not be a port of scikit-learn. It should be a Rust-native design that achieves the same outcomes: a toolkit where a practitioner can go from raw data to a trained, validated, serializable model without leaving the library ecosystem or fighting the borrow checker.
- Complete parity with scikit-learn's classical ML algorithm coverage
- Zero mandatory unsafe in the public API
- Composable pipelines with full type safety at compile time
- Native sparse matrix support — a first-class citizen, not an afterthought
- Serialization/deserialization of trained models via `serde`
- Optional GPU acceleration via a pluggable backend trait
- `no_std` compatibility for core algorithms where feasible
- Interoperability with `ndarray`, `polars`, and `arrow`
- Formally verified accuracy parity with scikit-learn via a mandatory six-layer correctness stack (see Section 20)
- A deep learning framework (that is a separate domain)
- Python bindings (a separate crate should handle this via `pyo3`)
- A data visualization layer (delegate to `plotters` or similar)
- Replacing `ndarray` or `faer` as the linear algebra primitive
Everything in the library implements at least one of three core traits. This mirrors scikit-learn's design but encodes it in Rust's type system.
/// A model that can be fit to data.
pub trait Fit<X, Y> {
type FitResult: Predict<X>;
type Error: std::error::Error;
fn fit(&self, x: &X, y: &Y) -> Result<Self::FitResult, Self::Error>;
}
/// A fitted model that can make predictions.
pub trait Predict<X> {
type Output;
type Error: std::error::Error;
fn predict(&self, x: &X) -> Result<Self::Output, Self::Error>;
}
/// A stateful data transformer (e.g. StandardScaler).
pub trait Transform<X> {
type Output;
type Error: std::error::Error;
fn transform(&self, x: &X) -> Result<Self::Output, Self::Error>;
}
/// Fit and transform in one step.
pub trait FitTransform<X>: Sized {
type Output;
type Fitted: Transform<X>;
type Error: std::error::Error;
fn fit_transform(self, x: &X) -> Result<(Self::Fitted, Self::Output), Self::Error>;
}

The key design decision here: `fit()` borrows the unfitted configuration and returns a value of a distinct fitted type. This enforces at compile time that you cannot call `predict()` on an unfitted model — no unfitted type implements `Predict`. This is a strict improvement over scikit-learn, where calling `predict()` before `fit()` is a runtime error.
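The type-state guarantee can be demonstrated with a minimal, dependency-free sketch. The `MeanModel` type and its API here are illustrative stand-ins, not part of the proposed surface:

```rust
use std::convert::Infallible;

/// Unfitted configuration: deliberately has no `predict` method.
pub struct MeanModel;

/// Fitted model: the only type that can predict.
pub struct FittedMeanModel {
    mean: f64,
}

impl MeanModel {
    /// Fitting returns a *different* type; since the unfitted type never
    /// implements prediction, `predict`-before-`fit` cannot compile.
    pub fn fit(&self, y: &[f64]) -> Result<FittedMeanModel, Infallible> {
        let mean = y.iter().sum::<f64>() / y.len() as f64;
        Ok(FittedMeanModel { mean })
    }
}

impl FittedMeanModel {
    /// Predict the training mean for each requested sample.
    pub fn predict(&self, n_samples: usize) -> Vec<f64> {
        vec![self.mean; n_samples]
    }
}
```

Attempting `MeanModel.predict(…)` is a compile error, which is exactly the property the trait design above generalizes.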
The library must be agnostic over input representation. A blanket-impl strategy over a Dataset trait allows flexibility.
pub trait Dataset {
fn n_samples(&self) -> usize;
fn n_features(&self) -> usize;
fn is_sparse(&self) -> bool { false }
}
// Implementations for:
// - ndarray::Array2<f32>
// - ndarray::Array2<f64>
// - SparseMatrix<f32> (CSR and CSC formats, see Section 8)
// - polars::DataFrame (via feature flag)
// - arrow::RecordBatch (via feature flag)

Numeric precision should be parametric: algorithms should be generic over `Float: num_traits::Float + Send + Sync`.
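A minimal sketch of the `Dataset` abstraction, using a hypothetical `DenseMatrix` stand-in instead of `ndarray` so the example is dependency-free:

```rust
/// Illustrative row-major dense matrix; the real impls target
/// `ndarray::Array2<F>`, sparse types, and (behind flags) dataframes.
pub struct DenseMatrix<F> {
    data: Vec<F>,
    rows: usize,
    cols: usize,
}

impl<F: Clone> DenseMatrix<F> {
    pub fn from_rows(rows: &[Vec<F>]) -> Self {
        let r = rows.len();
        let c = rows.first().map_or(0, |row| row.len());
        let data = rows.iter().flat_map(|row| row.iter().cloned()).collect();
        DenseMatrix { data, rows: r, cols: c }
    }
}

pub trait Dataset {
    fn n_samples(&self) -> usize;
    fn n_features(&self) -> usize;
    fn is_sparse(&self) -> bool {
        false // dense by default; sparse representations override
    }
}

impl<F> Dataset for DenseMatrix<F> {
    fn n_samples(&self) -> usize {
        self.rows
    }
    fn n_features(&self) -> usize {
        self.cols
    }
}
```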
The pipeline is the most important ergonomic feature in scikit-learn. In Rust, this requires careful design to avoid losing type information.
// Desired usage:
let pipeline = Pipeline::new()
.step("scaler", StandardScaler::new())
.step("pca", PCA::new().n_components(10))
.step("clf", LogisticRegression::new().max_iter(1000));
let fitted = pipeline.fit(&x_train, &y_train)?;
let predictions = fitted.predict(&x_test)?;

Internally, pipeline steps are stored as boxed trait objects to allow heterogeneous step types while keeping the API ergonomic. A compile-time pipeline (using const generics or type-level lists) should be available as an opt-in for zero-cost abstraction.
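The boxed-trait-object storage strategy can be sketched as follows. The `Step` trait and the stateless transforms here are simplified stand-ins for the fitted `Transform` trait described earlier:

```rust
/// Simplified stateless transform over a feature vector; the real pipeline
/// would chain fitted `Transform` implementations instead.
trait Step {
    fn apply(&self, x: Vec<f64>) -> Vec<f64>;
}

struct Center { offset: f64 }
impl Step for Center {
    fn apply(&self, x: Vec<f64>) -> Vec<f64> {
        x.into_iter().map(|v| v - self.offset).collect()
    }
}

struct Scale { factor: f64 }
impl Step for Scale {
    fn apply(&self, x: Vec<f64>) -> Vec<f64> {
        x.into_iter().map(|v| v * self.factor).collect()
    }
}

/// Heterogeneous steps stored as named, boxed trait objects.
struct Pipeline {
    steps: Vec<(String, Box<dyn Step>)>,
}

impl Pipeline {
    fn new() -> Self {
        Pipeline { steps: Vec::new() }
    }
    /// Builder-style step registration, mirroring the desired usage above.
    fn step(mut self, name: &str, s: impl Step + 'static) -> Self {
        self.steps.push((name.to_string(), Box::new(s)));
        self
    }
    /// Run every step in registration order.
    fn apply(&self, mut x: Vec<f64>) -> Vec<f64> {
        for (_, s) in &self.steps {
            x = s.apply(x);
        }
        x
    }
}
```

Dynamic dispatch costs one vtable call per step per batch, which is negligible next to the numerical work; the opt-in type-level pipeline removes even that.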
The following table maps scikit-learn's algorithm categories to required implementations. All algorithms must support both f32 and f64.
| Algorithm | Priority | Notes |
|---|---|---|
| Logistic Regression | P0 | L1, L2, ElasticNet penalties; multi-class via OvR and softmax |
| Linear SVM (SVC) | P0 | SMO solver; kernel trick via pluggable Kernel trait |
| Kernel SVM | P1 | RBF, polynomial, sigmoid kernels |
| k-Nearest Neighbors | P0 | Ball tree and KD-tree backends |
| Decision Tree | P0 | Gini and entropy criteria; max depth, min samples controls |
| Random Forest | P0 | Parallelized via Rayon |
| Gradient Boosting | P0 | Must include histogram-based variant (HistGB) |
| AdaBoost | P1 | |
| Naive Bayes | P0 | Gaussian, Multinomial, Bernoulli, Complement variants |
| Linear Discriminant Analysis | P1 | |
| Quadratic Discriminant Analysis | P2 | |
| Perceptron | P1 | |
| Ridge Classifier | P1 | |
| Algorithm | Priority | Notes |
|---|---|---|
| Linear Regression | P0 | OLS via QR decomposition; closed-form and iterative solvers |
| Ridge | P0 | L2 regularization |
| Lasso | P0 | Coordinate descent solver |
| ElasticNet | P0 | |
| Bayesian Ridge | P1 | |
| SGD Regressor | P0 | Mini-batch support; multiple loss functions |
| SVR | P1 | |
| k-Nearest Neighbors Regression | P0 | |
| Decision Tree Regressor | P0 | |
| Random Forest Regressor | P0 | |
| Gradient Boosting Regressor | P0 | |
| Huber Regressor | P1 | |
| Isotonic Regression | P2 | |
| RANSAC | P2 | |
| Algorithm | Priority | Notes |
|---|---|---|
| k-Means | P0 | k-Means++ initialization; parallelized via Rayon |
| Mini-Batch k-Means | P1 | |
| DBSCAN | P0 | |
| HDBSCAN | P1 | |
| Agglomerative Clustering | P1 | Ward, complete, average, single linkage |
| Gaussian Mixture Models | P1 | EM algorithm; full, tied, diag, spherical covariance |
| Mean Shift | P2 | |
| Spectral Clustering | P2 | |
| OPTICS | P2 | |
| Birch | P2 | |
| Algorithm | Priority | Notes |
|---|---|---|
| PCA | P0 | Full, truncated (randomized SVD), incremental variants |
| Truncated SVD | P0 | Works on sparse matrices — critical for NLP |
| Kernel PCA | P1 | |
| t-SNE | P1 | Barnes-Hut approximation for large datasets |
| UMAP | P1 | |
| LDA (topic model) | P1 | |
| Factor Analysis | P2 | |
| ICA | P2 | |
| NMF | P1 | Multiplicative update and coordinate descent solvers |
| Dictionary Learning | P2 | |
| Algorithm | Priority | Notes |
|---|---|---|
| Isomap | P2 | |
| Locally Linear Embedding | P2 | |
| MDS | P2 | |
| Spectral Embedding | P2 | |
Preprocessing transformers must implement FitTransform and be pipeline-composable.
| Transformer | Notes |
|---|---|
| `StandardScaler` | Zero mean, unit variance; handles f32/f64 |
| `MinMaxScaler` | Configurable feature range |
| `RobustScaler` | Median and IQR based; resistant to outliers |
| `MaxAbsScaler` | Scales each feature by its max absolute value |
| `Normalizer` | Per-sample L1, L2, or max normalization |
| `PowerTransformer` | Box-Cox and Yeo-Johnson |
| `QuantileTransformer` | Uniform and Gaussian output distributions |
| Transformer | Notes |
|---|---|
| `OneHotEncoder` | Handles unknown categories at transform time |
| `OrdinalEncoder` | |
| `LabelEncoder` | |
| `TargetEncoder` | With cross-fitting to prevent leakage |
| `BinaryEncoder` | |
| Transformer | Notes |
|---|---|
| `PolynomialFeatures` | Interaction terms; configurable degree |
| `SplineTransformer` | B-spline basis expansion |
| `KBinsDiscretizer` | Uniform, quantile, k-means binning strategies |
| `Binarizer` | |
| `FunctionTransformer` | Wrap any `Fn` as a transformer |
| Transformer | Notes |
|---|---|
| `SimpleImputer` | Mean, median, most frequent, constant |
| `KNNImputer` | |
| `IterativeImputer` | MICE-style multivariate imputation |
| Transformer | Notes |
|---|---|
| `VarianceThreshold` | |
| `SelectKBest` | Pluggable scoring functions |
| `SelectPercentile` | |
| `RFE` | Recursive Feature Elimination |
| `RFECV` | RFE with cross-validation |
| `SelectFromModel` | Threshold on `feature_importances_` or coefficients |
| `SequentialFeatureSelector` | Forward and backward |
let cv = KFold::new(5).shuffle(true).random_state(42);
let scores = cross_val_score(&estimator, &x, &y, &cv, accuracy_score)?;

Required splitters:

- `KFold` / `StratifiedKFold` / `GroupKFold`
- `ShuffleSplit` / `StratifiedShuffleSplit`
- `TimeSeriesSplit`
- `LeaveOneOut` / `LeavePOut`
- `LeaveOneGroupOut` / `LeavePGroupsOut`
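As an illustration of splitter semantics, here is a dependency-free sketch of unshuffled `KFold` index generation, following the common convention (also scikit-learn's) that the first `n_samples % k` folds receive one extra sample:

```rust
/// Generate (train, test) index pairs for k-fold CV without shuffling.
fn kfold_indices(n_samples: usize, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    let base = n_samples / k;
    let extra = n_samples % k; // first `extra` folds get one more sample
    let mut splits = Vec::with_capacity(k);
    let mut start = 0;
    for fold in 0..k {
        let size = base + if fold < extra { 1 } else { 0 };
        let test: Vec<usize> = (start..start + size).collect();
        let train: Vec<usize> = (0..n_samples)
            .filter(|i| !(start..start + size).contains(i))
            .collect();
        splits.push((train, test));
        start += size;
    }
    splits
}
```

A shuffling variant would permute the index vector with a seeded RNG before slicing, which is what makes `random_state` reproducibility possible.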
let param_grid = param_grid! {
"max_depth" => [3, 5, 10, None],
"min_samples_split" => [2, 5, 10],
};
let search = GridSearchCV::new(estimator, param_grid)
.cv(StratifiedKFold::new(5))
.scoring(f1_score)
.n_jobs(-1); // Use all available threads via Rayon
let fitted_search = search.fit(&x_train, &y_train)?;
println!("Best params: {:?}", fitted_search.best_params());

| Search Method | Notes |
|---|---|
| `GridSearchCV` | Exhaustive grid; parallel via Rayon |
| `RandomizedSearchCV` | Samples from distributions |
| `HalvingGridSearchCV` | Successive halving for large grids |
Hyperparameter distributions must be expressible:
pub trait Distribution<T>: Send + Sync {
fn sample(&self, rng: &mut impl Rng) -> T;
}
// Provided: Uniform, LogUniform, Normal, LogNormal, IntUniform, Choice

All metrics must work with both owned and borrowed arrays.
Classification:
- `accuracy_score`, `balanced_accuracy_score`
- `precision_score`, `recall_score`, `f1_score`, `fbeta_score`
- `roc_auc_score`, `average_precision_score`
- `confusion_matrix`, `classification_report`
- `matthews_corrcoef`, `cohen_kappa_score`
- `log_loss`, `brier_score_loss`
- `roc_curve`, `precision_recall_curve`
Regression:
- `mean_absolute_error`, `mean_squared_error`, `root_mean_squared_error`
- `mean_absolute_percentage_error`
- `r2_score`, `explained_variance_score`
- `median_absolute_error`
- `mean_squared_log_error`
- `max_error`
- `d2_tweedie_score`
Clustering:
- `adjusted_rand_score`, `rand_score`
- `adjusted_mutual_info_score`, `normalized_mutual_info_score`
- `homogeneity_score`, `completeness_score`, `v_measure_score`
- `silhouette_score`, `silhouette_samples`
- `calinski_harabasz_score`
- `davies_bouldin_score`
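To give a flavor of the metrics API, here is a dependency-free sketch of `accuracy_score` and `confusion_matrix` over integer labels; the real implementations would be generic over label types and array backends:

```rust
/// Fraction of exact label matches; slices must have equal length.
fn accuracy_score(y_true: &[usize], y_pred: &[usize]) -> f64 {
    assert_eq!(y_true.len(), y_pred.len(), "shape mismatch");
    let hits = y_true.iter().zip(y_pred).filter(|(t, p)| t == p).count();
    hits as f64 / y_true.len() as f64
}

/// n_classes x n_classes counts: rows = true label, cols = predicted label.
fn confusion_matrix(y_true: &[usize], y_pred: &[usize], n_classes: usize) -> Vec<Vec<usize>> {
    let mut m = vec![vec![0usize; n_classes]; n_classes];
    for (&t, &p) in y_true.iter().zip(y_pred) {
        m[t][p] += 1;
    }
    m
}
```

Note that each row of the confusion matrix sums to the count of its true label, which is exactly the kind of invariant the property tests in Section 20.2 pin down.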
This is the most commonly cited gap in the Rust ML ecosystem. Sparse support must be a first-class primitive, not an optional extension.
pub enum SparseFormat {
CSR, // Compressed Sparse Row — fast row slicing, matrix-vector products
CSC, // Compressed Sparse Column — fast column slicing
COO, // Coordinate format — good for construction
LIL, // List of Lists — good for incremental construction
DOK, // Dictionary of Keys — fast element access
}

All algorithms that make statistical sense on sparse data (TruncatedSVD, NMF, Naive Bayes, linear models) must accept sparse inputs natively without converting to dense.
Required operations:
- Matrix-vector and matrix-matrix multiply (via BLAS-compatible routines)
- Elementwise arithmetic preserving sparsity
- Slicing (rows, columns, arbitrary indexing)
- Conversion between formats
- Vertical and horizontal stacking (`vstack`, `hstack`)
- `toarray()` / `todense()` conversions
- Arithmetic with dense arrays
pub struct CsrMatrix<T> {
data: Vec<T>, // Non-zero values
indices: Vec<usize>, // Column indices for each value
indptr: Vec<usize>, // Row pointer array
shape: (usize, usize),
}

The library should integrate with or wrap `sprs` where it is sufficient, but not depend on it for correctness of algorithm implementations.
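To make the CSR layout concrete, here is a sketch of a matrix-vector product over the structure above, monomorphized to `f64` for brevity (the real type is generic over the element type):

```rust
/// CSR layout matching the struct above.
struct CsrMatrix {
    data: Vec<f64>,      // non-zero values
    indices: Vec<usize>, // column index of each stored value
    indptr: Vec<usize>,  // indptr[r]..indptr[r+1] spans row r
    shape: (usize, usize),
}

impl CsrMatrix {
    /// y = A * x, touching only the stored non-zeros — this is why CSR
    /// mat-vec is fast regardless of how sparse the matrix is.
    fn matvec(&self, x: &[f64]) -> Vec<f64> {
        assert_eq!(x.len(), self.shape.1, "dimension mismatch");
        let mut y = vec![0.0; self.shape.0];
        for row in 0..self.shape.0 {
            for k in self.indptr[row]..self.indptr[row + 1] {
                y[row] += self.data[k] * x[self.indices[k]];
            }
        }
        y
    }
}
```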
Parallelism should be opt-in at the algorithm level and transparent to the caller.
let rf = RandomForest::new()
.n_estimators(100)
.n_jobs(-1); // -1 = use all cores, matching scikit-learn convention

Internally, `-1` maps to `rayon::current_num_threads()`. Individual algorithms document their parallelism strategy.
Rules:
- Rayon for CPU parallelism across samples and estimators
- All parallel code must be deterministic when a `random_state` seed is provided
- Reproducibility is a hard requirement, not a best-effort promise
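One way to satisfy the determinism rule under work-stealing parallelism is to derive every estimator's RNG seed purely from `(random_state, estimator index)`, never from thread identity or execution order. A sketch using a SplitMix64-style mixer (an illustrative choice, not a mandated one):

```rust
/// SplitMix64 finalizer: a cheap, well-mixed 64-bit hash step.
fn splitmix64(mut z: u64) -> u64 {
    z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// Each estimator's stream depends only on (random_state, estimator index),
/// so results are identical no matter which thread runs which estimator.
fn estimator_seed(random_state: u64, estimator_idx: u64) -> u64 {
    splitmix64(random_state ^ splitmix64(estimator_idx))
}
```

A random forest with `n_estimators = 100` would seed tree *i* with `estimator_seed(random_state, i)`, making the fit bit-reproducible across thread counts.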
Trained models must be serializable. This is a known gap in the current Rust ML ecosystem.
use ferrolearn::io::{save_model, load_model};
let model = RandomForest::new().n_estimators(100).fit(&x, &y)?;
// Save
save_model(&model, "model.fl")?;
// Load — type must be specified
let loaded: FittedRandomForest<f64> = load_model("model.fl")?;

- Native format: MessagePack or bincode plus a metadata envelope. Compact, fast, version-tagged.
- JSON export: human-readable via `serde_json` for debugging and interoperability
- ONNX export (P1): for deployment interop with runtimes like `tract` or `onnxruntime`
- PMML export (P2): for enterprise/legacy system interoperability
Every serialized model must include a schema version. Deserialization must fail fast with a clear error when the version is incompatible, rather than silently producing wrong results.
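A sketch of the fail-fast envelope check. The magic bytes and single-byte version field are illustrative assumptions about the envelope layout, not a specified format:

```rust
#[derive(Debug, PartialEq)]
enum LoadError {
    BadMagic,
    UnsupportedVersion(u8),
}

const MAGIC: &[u8; 4] = b"FERL"; // hypothetical magic number
const SUPPORTED_VERSION: u8 = 1;

/// Validate the envelope header before touching the payload; an
/// incompatible version is a hard error, never a silent reinterpretation.
fn check_envelope(bytes: &[u8]) -> Result<&[u8], LoadError> {
    if bytes.len() < 5 || !bytes.starts_with(MAGIC) {
        return Err(LoadError::BadMagic);
    }
    let version = bytes[4];
    if version != SUPPORTED_VERSION {
        return Err(LoadError::UnsupportedVersion(version));
    }
    Ok(&bytes[5..]) // payload follows the 5-byte header
}
```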
The library must support a pluggable compute backend to avoid locking into a single linear algebra stack.
pub trait Backend: Send + Sync {
fn gemm<F: Float>(&self, a: &Array2<F>, b: &Array2<F>) -> Array2<F>;
fn svd<F: Float>(&self, a: &Array2<F>) -> (Array2<F>, Array1<F>, Array2<F>);
// ... etc
}
pub struct NdarrayBackend; // Default; pure Rust via ndarray + faer
pub struct BlasBackend; // Links system BLAS/LAPACK (OpenBLAS, MKL)
pub struct CudaBackend;    // Optional GPU via cuBLAS (feature flag)

This allows the library to be used in environments without BLAS (e.g. embedded, WASM) while still achieving near-LAPACK performance in production deployments where BLAS is available.
All public functions return Result<T, FerroError>. Panics are forbidden in library code except for internal invariant violations (which should be unreachable in correct usage).
#[non_exhaustive]
pub enum FerroError {
ShapeMismatch { expected: Shape, got: Shape },
InsufficientSamples { needed: usize, got: usize },
ConvergenceFailure { iterations: usize, tolerance: f64 },
InvalidParameter { name: &'static str, reason: String },
NumericalInstability { context: String },
IoError(std::io::Error),
SerdeError(String),
}

Every error variant must carry enough context to identify the root cause without a debugger.
Fitted models that support it must expose their learned parameters.
pub trait HasCoefficients {
fn coef(&self) -> ArrayView2<f64>;
fn intercept(&self) -> ArrayView1<f64>;
}
pub trait HasFeatureImportances {
fn feature_importances(&self) -> ArrayView1<f64>;
}
pub trait HasClasses<L> {
fn classes(&self) -> &[L];
}

This mirrors scikit-learn's `coef_`, `feature_importances_`, and `classes_` attributes. The Python trailing-underscore convention becomes a plain method name in Rust, since methods cannot be confused with parameters.
The library should ship utilities for working with datasets, both real and synthetic.
use ferrolearn::datasets;
let iris = datasets::load_iris();
let (x, y) = (iris.data, iris.target);

Required toy datasets: Iris, Digits, Wine, Breast Cancer, Diabetes, Linnerud, Olivetti Faces.
let (x, y) = make_classification()
.n_samples(1000)
.n_features(20)
.n_informative(10)
.random_state(0)
.generate()?;

Required generators: `make_classification`, `make_regression`, `make_blobs`, `make_moons`, `make_circles`, `make_swiss_roll`, `make_s_curve`, `make_sparse_uncorrelated`.
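A dependency-free sketch of a `make_blobs`-style generator, with a small LCG standing in for a real seeded RNG; the signature is illustrative, not the proposed builder API:

```rust
/// Isotropic 2-D points around the given centers, labeled by center index.
/// Deterministic in `seed`, per the reproducibility requirement.
fn make_blobs(
    centers: &[(f64, f64)],
    samples_per_center: usize,
    spread: f64,
    seed: u64,
) -> (Vec<(f64, f64)>, Vec<usize>) {
    let mut state = seed;
    let mut next_unit = move || {
        // LCG step; map the high bits to a value in [-1, 1)
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((state >> 11) as f64 / (1u64 << 53) as f64) * 2.0 - 1.0
    };
    let mut x = Vec::new();
    let mut y = Vec::new();
    for (label, &(cx, cy)) in centers.iter().enumerate() {
        for _ in 0..samples_per_center {
            x.push((cx + spread * next_unit(), cy + spread * next_unit()));
            y.push(label);
        }
    }
    (x, y)
}
```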
let (x_train, x_test, y_train, y_test) =
train_test_split(&x, &y, 0.2, Some(42))?;

The repository should be organized as a Cargo workspace:
ferrolearn/
├── ferrolearn/ # Main crate — re-exports everything
├── ferrolearn-core/ # Traits, errors, Dataset abstractions
├── ferrolearn-linear/ # Linear models
├── ferrolearn-tree/ # Decision trees, random forests, boosting
├── ferrolearn-cluster/ # Clustering algorithms
├── ferrolearn-decomp/ # PCA, SVD, NMF, manifold methods
├── ferrolearn-preprocess/ # All preprocessing transformers
├── ferrolearn-metrics/ # All metrics
├── ferrolearn-model-sel/ # Cross-validation, hyperparameter search
├── ferrolearn-sparse/ # Sparse matrix types and operations
├── ferrolearn-datasets/ # Toy datasets and generators
└── ferrolearn-io/ # Model serialization and ONNX/PMML export
Feature flags on the main crate:
| Feature | Default | Description |
|---|---|---|
| `full` | No | Enables everything |
| `blas` | No | Link system BLAS/LAPACK |
| `cuda` | No | GPU backend via cuBLAS |
| `polars` | No | `polars::DataFrame` as Dataset |
| `arrow` | No | `arrow::RecordBatch` as Dataset |
| `onnx` | No | ONNX model export |
| `rayon` | Yes | Parallel execution |
| `serde` | Yes | Model serialization |
Note: All correctness, oracle, property-based, statistical equivalence, fuzz, and formal verification requirements are defined authoritatively in Section 20 and are hard release gates. The summaries below are superseded by Section 20.
See Section 20 for the complete mandatory correctness stack. Every algorithm requires oracle fixture tests (20.1), property-based invariant tests (20.2), statistical equivalence benchmarking (20.3), algorithm equivalence documentation (20.4), fuzz targets (20.5), and formal verification of metrics and data structures (20.6).
Use criterion for all performance benchmarks. Benchmarks must run against:
- Small (100 × 10), medium (10k × 100), and large (100k × 1000) datasets
- Both `f32` and `f64`
- Dense and sparse inputs where applicable
Benchmark targets: match or exceed scikit-learn+NumPy throughput on CPU for all classical algorithms. Performance benchmarks do not gate releases independently — correctness always takes precedence.
Every public item must have:
- A one-line summary
- A description of the algorithm and its time/space complexity
- Parameter documentation with valid ranges and defaults
- At least one complete, runnable `# Examples` block
- A `# References` section with the authoritative paper or textbook citation
The crate targets Rust 1.75 (stable, December 2023) as the MSRV to maximize compatibility while leveraging impl Trait in return position, async fn in traits (where needed), and const generics.
- Core traits (`Fit`, `Predict`, `Transform`, `FitTransform`)
- Dense matrix integration with `ndarray` and `faer`
- Sparse matrix types (CSR, CSC, COO)
- `StandardScaler`, `MinMaxScaler`, `RobustScaler`
- `OneHotEncoder`, `LabelEncoder`
- `train_test_split`, `KFold`, `StratifiedKFold`
- All regression metrics, all classification metrics
- Logistic Regression, Linear Regression, Ridge, Lasso
- `Pipeline` (dynamic dispatch variant)
- Decision Tree (classification + regression)
- Random Forest (classification + regression)
- k-Nearest Neighbors
- Naive Bayes (all variants)
- SVM (linear; kernel as stretch goal)
- k-Means, DBSCAN
- PCA, TruncatedSVD
- `GridSearchCV`, `RandomizedSearchCV`
- Model serialization (native format)
- Toy datasets + generators
- Gradient Boosting (including histogram variant)
- Full unsupervised suite (GMM, HDBSCAN, agglomerative)
- Full dimensionality reduction suite (t-SNE, UMAP, NMF)
- All imputers
- Full feature selection suite
- `TimeSeriesSplit`
- ONNX export
- BLAS backend
- Polars + Arrow integration
- GPU backend (CUDA)
- Online/streaming learning API for estimators that support it
- Calibration (`CalibratedClassifierCV`)
- Semi-supervised learning
- `ColumnTransformer` equivalent with column selection from `polars`
This section defines the mandatory correctness infrastructure that must ship alongside every algorithm. Passing all layers described here is a hard release gate — no algorithm may be published to crates.io without satisfying all applicable tiers. The goal is to make provable accuracy parity with scikit-learn a first-class, auditable property of the library, not an informal claim.
There are three distinct kinds of correctness that must each be addressed, and no single technique covers all three:
- Numerical correctness — does the implementation produce the same floating-point results as scikit-learn on identical inputs?
- Algorithmic correctness — does the procedure satisfy its mathematical specification as a logical invariant, independent of any reference implementation?
- Statistical correctness — does the fitted model achieve equivalent predictive quality on real-world data?
The six verification layers below address these in combination.
Every algorithm must have a fixture suite generated from scikit-learn's output and committed to the repository. Fixtures are the authoritative ground truth for numerical correctness.
Generation: A Python script (scripts/generate_fixtures.py) runs scikit-learn with fixed random seeds on a curated set of inputs and writes the results to JSON files under fixtures/. This script must be re-run and fixtures re-committed whenever scikit-learn releases a new version that changes numerical behavior.
# Example fixture generation for LogisticRegression
import numpy as np, json
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
rng = np.random.RandomState(42)
X, y = make_classification(n_samples=200, n_features=10, random_state=42)
model = LogisticRegression(C=1.0, max_iter=1000, random_state=42).fit(X, y)
fixture = {
"sklearn_version": sklearn.__version__,
"X": X.tolist(),
"y": y.tolist(),
"coef": model.coef_.tolist(),
"intercept": model.intercept_.tolist(),
"predictions": model.predict(X).tolist(),
"probabilities": model.predict_proba(X).tolist(),
}
with open("fixtures/logistic_regression_basic.json", "w") as f:
json.dump(fixture, f)

// Corresponding Rust oracle test
#[test]
fn logistic_regression_matches_sklearn_basic() {
let fixture: Fixture = load_fixture("logistic_regression_basic.json");
let model = LogisticRegression::new()
.C(1.0)
.max_iter(1000)
.random_state(42)
.fit(&fixture.X, &fixture.y)
.unwrap();
assert_ulps_eq!(model.coef(), fixture.coef, max_ulps = 4);
assert_ulps_eq!(model.intercept(), fixture.intercept, max_ulps = 4);
assert_eq!(model.predict(&fixture.X).unwrap(), fixture.predictions);
assert_ulps_eq!(model.predict_proba(&fixture.X).unwrap(), fixture.probabilities, max_ulps = 4);
}

Tolerance standard: All floating-point comparisons must use ULP-based tolerances, not absolute epsilon. The budget is:
| Output type | Max ULP tolerance |
|---|---|
| Fitted coefficients / weights | 4 ULPs |
| Predictions (class labels) | Exact match |
| Probabilities / scores | 4 ULPs |
| Metrics (accuracy, R², etc.) | 4 ULPs |
| Iterative solver outputs (>1000 iterations) | 10 ULPs |
The rationale for ULP-based comparison: pure absolute epsilon tolerances do not scale correctly across the floating-point range and will produce false passes for large values and false failures for small ones. ULP differences measure how many representable floating-point numbers separate the two values, which is the correct measure of floating-point rounding divergence.
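The ULP distance itself is straightforward to compute for finite `f64` values; a sketch of the standard bit-twiddling approach (an `assert_ulps_eq!` macro would wrap something like this):

```rust
/// Number of representable f64 values between a and b (0 = bit-identical
/// up to signed zero). Maps each bit pattern onto a monotonic integer line,
/// then takes the distance; valid for finite, non-NaN inputs.
fn ulp_distance(a: f64, b: f64) -> u64 {
    fn ordered(x: f64) -> i64 {
        let bits = x.to_bits() as i64;
        // Sign-magnitude bit patterns sort backwards for negatives;
        // remap so the whole f64 line is monotone in the integer.
        if bits < 0 { i64::MIN.wrapping_sub(bits) } else { bits }
    }
    ((ordered(a) as i128) - (ordered(b) as i128)).unsigned_abs() as u64
}
```

Under this measure, `0.1 + 0.2` and `0.3` are one ULP apart, which passes the 4-ULP budget, while any absolute-epsilon threshold would need rescaling per magnitude.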
Required fixture scenarios per algorithm:
- Standard case (well-conditioned data, default hyperparameters)
- Non-default hyperparameters (at least three distinct configurations)
- Edge case: minimum viable input (e.g. `n_samples = n_features + 1`)
- Edge case: single-class or near-degenerate data (where applicable)
- Edge case: features with zero variance
- Edge case: very large values (`X` scaled to `[1e6, 1e7]`)
- Edge case: very small values (`X` scaled to `[1e-7, 1e-6]`)
- Sparse input equivalent (for all algorithms that accept sparse input)
CI enforcement: Fixture tests run on every pull request. A PR that changes algorithm output without a corresponding fixture update must be rejected by CI.
Property-based tests verify mathematical invariants that must hold for all valid inputs, not just the fixture inputs. They use proptest to generate randomized inputs and assert logical properties derived from the algorithm's mathematical definition.
Unlike oracle tests, property tests do not require access to scikit-learn — they test correctness against the mathematical specification directly. They catch classes of bugs that oracle tests cannot, including failures on inputs that the fixture author did not anticipate.
Tooling: proptest crate. Strategies must be defined for each algorithm's valid input domain, with appropriate shrinking so that counterexamples are minimized on failure.
Required invariants by category:
Scalers and preprocessors:
// StandardScaler: transformed training data must have zero mean and unit variance
proptest! {
fn scaler_zero_mean(matrix in valid_matrix_f64()) {
let (_, t) = StandardScaler::new().fit_transform(&matrix).unwrap();
let means = t.mean_axis(Axis(0)).unwrap();
prop_assert!(means.iter().all(|m| m.abs() < 1e-10));
}
fn scaler_unit_variance(matrix in valid_matrix_f64()) {
let (_, t) = StandardScaler::new().fit_transform(&matrix).unwrap();
let stds = t.std_axis(Axis(0), 0.0);
prop_assert!(stds.iter().all(|s| (s - 1.0).abs() < 1e-10));
}
// fit_transform(X) == fit(X).transform(X) for all valid X
fn scaler_fit_transform_equivalence(matrix in valid_matrix_f64()) {
let (fitted, t1) = StandardScaler::new().fit_transform(&matrix).unwrap();
let t2 = fitted.transform(&matrix).unwrap();
prop_assert_ulps_eq!(t1, t2, max_ulps = 1);
}
}

Classifiers:
// predict_proba rows must sum to 1.0
fn classifier_proba_sums_to_one(X in valid_matrix(), y in valid_labels()) { ... }
// classes() must be sorted and contain exactly the unique labels seen during fit
fn classifier_classes_sorted_and_complete(X in valid_matrix(), y in valid_labels()) { ... }
// predict(X) == argmax(predict_proba(X)) for all probabilistic classifiers
fn predict_consistent_with_proba(X in valid_matrix(), y in valid_labels()) { ... }
// Re-fitting with identical data and seed must produce identical output
fn classifier_deterministic_with_seed(X in valid_matrix(), y in valid_labels()) { ... }

Regressors:
// R² on training data for a sufficiently expressive model must be ≥ 0.0
// predict() output shape must match input n_samples
// Predicting a constant target must produce zero residuals for LinearRegression

Clustering:
// Every sample must be assigned to its nearest centroid (k-Means hard invariant)
fn kmeans_nearest_centroid_assignment(X in valid_matrix(), k in 2usize..=10) {
let model = KMeans::new(k).fit(&X).unwrap();
let labels = model.labels();
let centroids = model.cluster_centers();
for (i, &label) in labels.iter().enumerate() {
let sample = X.row(i);
let assigned_dist = dist(&sample, ¢roids.row(label));
for (j, centroid) in centroids.rows().into_iter().enumerate() {
prop_assert!(assigned_dist <= dist(&sample, ¢roid) + 1e-10);
}
}
}
// Inertia must be non-negative
// n_iter_ must be ≥ 1

Dimensionality reduction:
// PCA components must be orthonormal: components @ components.T ≈ I
fn pca_components_orthonormal(X in valid_matrix(), n_comp in 1usize..=5) { ... }
// Explained variance ratios must be non-negative and sum to ≤ 1.0
fn pca_explained_variance_valid(X in valid_matrix(), n_comp in 1usize..=5) { ... }
// Reconstruction error must decrease monotonically as n_components increases
fn pca_reconstruction_monotone(X in valid_matrix()) { ... }

Pipelines:
// Pipeline.fit(X, y).predict(X) must equal Pipeline.fit_predict(X, y)
// A pipeline with a no-op transformer inserted must produce identical predictions
// fit_transform on a pipeline must equal fit then transform

Metrics:
// accuracy_score must be in [0.0, 1.0] for all inputs
// confusion_matrix rows must sum to the count of each true label
// r2_score on training data of LinearRegression must be ≥ r2_score of a constant predictor

Coverage requirement: Every public algorithm must have a minimum of 8 distinct property tests covering its core mathematical invariants. PRs introducing a new algorithm without the full property test suite must be rejected.
Oracle tests prove numerical closeness on fixture data. Statistical benchmarking proves that predictive quality is equivalent on real, diverse datasets. This is a separate concern — an implementation could match all fixture outputs to within 1 ULP and still underperform scikit-learn on novel data due to a subtle algorithmic difference in solver convergence or initialization.
Infrastructure: A Python harness (benchmarks/statistical_equivalence.py) runs both implementations on the same datasets and applies Welch's t-test to cross-validated scores. A Rust binary (benchmarks/ferrolearn_bench) outputs scores as JSON for the harness to consume.
import scipy.stats, json, subprocess
from statistics import mean
from sklearn.model_selection import cross_val_score
DATASETS = [
# scikit-learn built-ins
"iris", "digits", "wine", "breast_cancer", "diabetes",
# OpenML benchmarks — diverse domains and sizes
"credit-g", # credit scoring, 1000 samples, 20 features
"adult", # income classification, 48k samples
"covertype", # multiclass, 500k samples
"bank-marketing", # imbalanced binary classification
"california_housing", # regression
"ames_housing", # regression, high-dimensional
]
ALGORITHMS = [
"logistic_regression", "random_forest", "decision_tree",
"knn", "naive_bayes_gaussian", "sgd_classifier",
"linear_regression", "ridge", "lasso", "svr",
"kmeans", # evaluated by silhouette score
"pca", # evaluated by reconstruction error on held-out data
]
def run_equivalence_test(algo, dataset, cv=10):
sklearn_scores = cross_val_score(get_sklearn(algo), X, y, cv=cv, scoring=metric)
result = subprocess.run(
["./target/release/ferrolearn_bench", algo, dataset, str(cv)],
capture_output=True
)
ferrolearn_scores = json.loads(result.stdout)["scores"]
t_stat, p_value = scipy.stats.ttest_ind(sklearn_scores, ferrolearn_scores, equal_var=False)  # Welch's t-test
# FAIL: ferrolearn is statistically significantly worse (one-sided, α=0.05)
assert not (p_value < 0.05 and mean(ferrolearn_scores) < mean(sklearn_scores)), (
f"FAIL: {algo} on {dataset}: ferrolearn significantly worse "
f"(ferrolearn={mean(ferrolearn_scores):.4f}, sklearn={mean(sklearn_scores):.4f}, p={p_value:.4f})"
)
# WARN: ferrolearn is meaningfully worse but not statistically significant
if mean(ferrolearn_scores) < mean(sklearn_scores) - 0.005:
print(f"WARN: {algo} on {dataset}: ferrolearn mean {mean(ferrolearn_scores):.4f} "
f"vs sklearn {mean(sklearn_scores):.4f} — investigate")

Hard requirements:
- The benchmark suite must cover a minimum of 10 datasets and all P0 algorithms before a 1.0 release
- No P0 algorithm may produce a statistically significantly worse result (Welch's t-test, α = 0.05, one-sided) than scikit-learn on any benchmark dataset
- Results of each benchmark run must be committed as a machine-readable artifact (`benchmarks/results/YYYY-MM-DD.json`) so regressions can be detected over time
- Benchmarks must run on every release candidate and on any PR that touches algorithm implementation code
For each algorithm, a structured documentation block must be maintained alongside the implementation that proves the Rust implementation uses the same algorithm variant as scikit-learn, not merely a similar one. This is the paper trail that converts "it seems to match" into "it implements the same procedure."
The documentation must be kept in docs/algorithm_equivalence/ as one Markdown file per algorithm. Each file must contain:
# Algorithm Equivalence: Logistic Regression (L-BFGS solver)
## scikit-learn Reference
- File: `sklearn/linear_model/_logistic.py`
- Commit: `abc1234` (scikit-learn 1.5.0)
- Relevant functions: `_logistic_loss_and_grad`, `_fit_liblinear`
## Mathematical Specification
Loss function: L(w) = -Σ [y_i log(p_i) + (1-y_i) log(1-p_i)] + (C⁻¹/2) ||w||²
Gradient: ∇L(w) = Xᵀ(p - y) + C⁻¹ w
where p_i = σ(X_i · w + b)
## Equivalence Claims
| Component | sklearn behavior | ferrolearn behavior | Status |
|-------------------|------------------------------------------|----------------------------|--------|
| Loss function | Log-loss with L2 penalty | Identical formulation | ✓ |
| Gradient | Analytical gradient of cross-entropy+L2 | Identical | ✓ |
| Solver | L-BFGS, m=10 history | L-BFGS, m=10 history | ✓ |
| Line search | Wolfe conditions | Wolfe conditions | ✓ |
| Convergence check | ‖grad‖∞ < tol | ‖grad‖∞ < tol | ✓ |
| Multiclass | OvR by default, multinomial w/ lbfgs | Identical dispatch | ✓ |
## Known Numerical Differences
- BLAS implementation differences: max observed 3 ULPs in coef_ on benchmark suite
- FMA instruction availability: CPU-dependent, max 1 ULP difference
- These are unavoidable consequences of floating-point non-associativity, not algorithmic differences
## Fixture Coverage
- `fixtures/logistic_regression_basic.json` — standard case
- `fixtures/logistic_regression_c0.01.json` — strong regularization
- `fixtures/logistic_regression_multinomial.json` — multiclass softmax
- `fixtures/logistic_regression_sparse.json` — CSR input
- `fixtures/logistic_regression_illconditioned.json` — near-singular X

Enforcement: No algorithm PR may be merged without the corresponding equivalence document. The document must be reviewed by a maintainer with numerical-analysis familiarity before merge.
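The loss/gradient pair stated in the mathematical specification is exactly the kind of claim a cheap numerical cross-check can guard in CI. A self-contained finite-difference sketch (plain Rust, sklearn's `C` convention, intercept omitted for brevity; none of this is proposed ferrolearn API):

```rust
fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

/// Log-loss with L2 penalty: -Σ [y log p + (1-y) log(1-p)] + ‖w‖²/(2C).
fn loss(w: &[f64], x: &[Vec<f64>], y: &[f64], c: f64) -> f64 {
    let mut l = 0.0;
    for (xi, &yi) in x.iter().zip(y) {
        let z: f64 = xi.iter().zip(w).map(|(a, b)| a * b).sum();
        let p = sigmoid(z);
        l -= yi * p.ln() + (1.0 - yi) * (1.0 - p).ln();
    }
    l + w.iter().map(|wi| wi * wi).sum::<f64>() / (2.0 * c)
}

/// Analytic gradient: Xᵀ(p − y) + w/C.
fn grad(w: &[f64], x: &[Vec<f64>], y: &[f64], c: f64) -> Vec<f64> {
    let mut g = vec![0.0; w.len()];
    for (xi, &yi) in x.iter().zip(y) {
        let z: f64 = xi.iter().zip(w).map(|(a, b)| a * b).sum();
        let r = sigmoid(z) - yi;
        for (gj, &xij) in g.iter_mut().zip(xi) {
            *gj += r * xij;
        }
    }
    for (gj, &wj) in g.iter_mut().zip(w) {
        *gj += wj / c;
    }
    g
}

fn main() {
    // Central finite differences must agree with the analytic gradient.
    let x = vec![vec![1.0, 2.0], vec![-1.0, 0.5], vec![0.3, -2.0]];
    let y = vec![1.0, 0.0, 1.0];
    let (w, c, h) = (vec![0.2, -0.4], 1.0, 1e-6);
    let g = grad(&w, &x, &y, c);
    for j in 0..w.len() {
        let mut wp = w.clone();
        wp[j] += h;
        let mut wm = w.clone();
        wm[j] -= h;
        let fd = (loss(&wp, &x, &y, c) - loss(&wm, &x, &y, c)) / (2.0 * h);
        assert!((fd - g[j]).abs() < 1e-5, "gradient mismatch at {j}");
    }
    println!("analytic gradient matches finite differences");
}
```

A check like this does not prove equivalence with scikit-learn, but it does catch a transcription error between the Markdown specification and the implementation.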
20.5 Layer 5 — Fuzzing (REQUIRED for every public entry point)

Fuzzing verifies that no combination of inputs — however malformed, adversarial, or degenerate — causes a panic, undefined behavior, silent NaN propagation, or an infinite loop. This is distinct from correctness: the fuzz target's contract is that the library always returns either a valid `Result::Ok` or a well-typed `Result::Err`; it never panics or hangs.
Tooling: `cargo-fuzz` with the libFuzzer backend. Fuzz targets live in `fuzz/fuzz_targets/`.
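The `Matrix::from_fuzz_bytes` constructor used by the fuzz targets is this proposal's hook for turning raw fuzzer bytes into structured input. One plausible decoding scheme, sketched here as a free function with an assumed wire format (two shape bytes, then little-endian `f64` cells, shape capped so the fuzzer cannot spend its budget on huge allocations):

```rust
/// Sketch of a fuzz-input decoder. Returns None for inputs that are
/// too short, so the fuzz target can simply skip them. The format and
/// the 32x32 cap are illustrative assumptions, not the proposed API.
fn matrix_from_fuzz_bytes(data: &[u8]) -> Option<(usize, usize, Vec<f64>)> {
    let (&rows, rest) = data.split_first()?;
    let (&cols, rest) = rest.split_first()?;
    // Cap the shape: a two-byte header must not demand gigabytes.
    let (rows, cols) = (rows as usize % 32 + 1, cols as usize % 32 + 1);
    let needed = rows * cols * 8;
    if rest.len() < needed {
        return None;
    }
    let cells: Vec<f64> = rest[..needed]
        .chunks_exact(8)
        .map(|c| f64::from_le_bytes(c.try_into().unwrap()))
        .collect();
    Some((rows, cols, cells))
}
```

Because the bytes are interpreted verbatim, the fuzzer naturally produces NaN, infinity, subnormal, and extreme-magnitude cells without any special casing.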
```rust
// fuzz/fuzz_targets/standard_scaler.rs
#![no_main]
use libfuzzer_sys::fuzz_target;
use ferrolearn::Matrix; // dense matrix type (path illustrative)
use ferrolearn::preprocess::StandardScaler;

fuzz_target!(|data: &[u8]| {
    if let Some(matrix) = Matrix::from_fuzz_bytes(data) {
        // Contract: must never panic. May return Err.
        let result = StandardScaler::new().fit(&matrix);
        if let Ok(scaler) = result {
            let _ = scaler.transform(&matrix);
        }
    }
});
```

Hard requirements:
- Every public-facing `fit`, `transform`, `predict`, and `fit_transform` function must have a fuzz target
- The fuzz corpus must be run for a minimum of 24 CPU-hours before any release
- Any panic discovered by fuzzing is a P0 bug that blocks release regardless of input validity
- NaN propagation through any public function on non-NaN input is a P0 bug
- Infinite loops or hangs (detectable via timeout) are P0 bugs
- The fuzz corpus (seed inputs) must be committed to `fuzz/corpus/` and grown over time
- Crashes found by fuzzing must produce a minimized reproducer and a regression test before the fix is merged
Specific adversarial inputs that must be covered by the fuzz seed corpus:
- All-zero matrix
- All-NaN matrix (must return `Err`, not panic)
- Matrix with a single row
- Matrix with a single column
- Matrix where `n_samples < n_features`
- Extremely large values (`f64::MAX`)
- Extremely small values (`f64::MIN_POSITIVE`)
- Values near overflow boundaries
- Matrices with duplicate rows
- Perfectly collinear feature matrices
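The "must return `Err`, not panic" entries imply a validation guard at the top of every public entry point. A minimal sketch of such a guard (the function name and error type are assumptions for illustration, not the proposed ferrolearn API):

```rust
/// Illustrative input guard: reject empty input and any non-finite
/// cell up front, so numeric kernels can assume finite data.
#[derive(Debug, PartialEq)]
enum ValidationError {
    Empty,
    NonFinite { index: usize },
}

fn validate_finite(cells: &[f64]) -> Result<(), ValidationError> {
    if cells.is_empty() {
        return Err(ValidationError::Empty);
    }
    match cells.iter().position(|v| !v.is_finite()) {
        Some(index) => Err(ValidationError::NonFinite { index }),
        None => Ok(()),
    }
}
```

Note that `f64::MAX` and `f64::MIN_POSITIVE` pass this guard: extreme-but-finite values are valid inputs the algorithms must handle, while NaN and infinity are rejected as typed errors.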
20.6 Layer 6 — Formal Verification of Core Primitives (REQUIRED for metrics and data structures; RECOMMENDED for algorithm kernels)
Full formal verification of all ML algorithms is not yet practical with current tooling. However, specific components are tractable for formal verification today and must be verified before the 1.0 release. Formal verification here means machine-checked proofs, not just testing.
Mandatory formally verified components:
Metric functions — All metric functions in `ferrolearn-metrics` are pure functions with unambiguous mathematical definitions. They must be verified using Prusti (a deductive program verifier for Rust) with pre/postcondition annotations:
```rust
#[requires(y_true.len() == y_pred.len())]
#[requires(y_true.len() > 0)]
#[ensures(0.0 <= result && result <= 1.0)]
pub fn accuracy_score(y_true: &[usize], y_pred: &[usize]) -> f64 {
    let correct = y_true.iter().zip(y_pred).filter(|(a, b)| a == b).count();
    correct as f64 / y_true.len() as f64
}
```

Sparse matrix structural invariants — The `CsrMatrix<T>` type must carry Prusti-verified invariants proving that after any public constructor or mutation:
- `indptr.len() == n_rows + 1`
- `indptr` is monotonically non-decreasing
- `indices[indptr[i]..indptr[i+1]]` are all in `0..n_cols` for every row `i`
- `data.len() == indices.len()`
- No index appears twice within the same row's slice
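Prusti would discharge these invariants statically; for illustration, the same properties can be written as one executable check over a struct-of-vectors stand-in (not the proposed `CsrMatrix<T>`; the no-duplicates rule is checked here via the stronger, common convention of strictly increasing column indices per row):

```rust
/// Struct-of-vectors stand-in used only to state the invariants.
struct Csr {
    n_rows: usize,
    n_cols: usize,
    indptr: Vec<usize>,
    indices: Vec<usize>,
    data: Vec<f64>,
}

/// Executable version of the structural invariants. Checks are ordered
/// so the row slices at the end are only taken once the bounds they
/// rely on are known to hold.
fn csr_invariants_hold(m: &Csr) -> bool {
    // indptr has exactly n_rows + 1 entries
    m.indptr.len() == m.n_rows + 1
        // indptr is monotonically non-decreasing
        && m.indptr.windows(2).all(|w| w[0] <= w[1])
        // data and indices are parallel arrays
        && m.data.len() == m.indices.len()
        // implied bound: the final indptr entry is the nnz count,
        // so every row slice below is in range
        && *m.indptr.last().unwrap() == m.indices.len()
        && (0..m.n_rows).all(|i| {
            let row = &m.indices[m.indptr[i]..m.indptr[i + 1]];
            // column indices in 0..n_cols and strictly increasing
            // (strictness also rules out duplicates within the row)
            row.iter().all(|&j| j < m.n_cols)
                && row.windows(2).all(|w| w[0] < w[1])
        })
}
```

In a verified build this function would be redundant; until then it can serve as the oracle for property-based tests over random mutations.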
Type-system proofs (already enforced by the compiler):
The compile-time guarantee that predict() cannot be called on an unfitted model — enforced by the Fit/Predict trait split described in Section 4.1 — is a formal proof carried by Rust's type checker. This must be documented explicitly as a correctness guarantee in the API documentation, not treated as an implementation detail.
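A minimal sketch of how that type-level guarantee works (a toy estimator with inherent methods, simplified from the Section 4.1 traits): only the value returned by `fit` has a `predict` method, so calling `predict` on an unfitted estimator fails to compile rather than failing at runtime.

```rust
/// Toy estimator illustrating the Fit/Predict split. The unfitted
/// type has no predict method at all; only the fit result does.
struct MeanModel;

struct FittedMeanModel {
    mean: f64,
}

impl MeanModel {
    /// The only way to obtain a FittedMeanModel is through fit.
    fn fit(&self, y: &[f64]) -> Result<FittedMeanModel, &'static str> {
        if y.is_empty() {
            return Err("cannot fit on an empty target vector");
        }
        Ok(FittedMeanModel {
            mean: y.iter().sum::<f64>() / y.len() as f64,
        })
    }
}

impl FittedMeanModel {
    fn predict(&self, n: usize) -> Vec<f64> {
        vec![self.mean; n]
    }
}

fn main() {
    let fitted = MeanModel.fit(&[1.0, 2.0, 3.0]).unwrap();
    assert_eq!(fitted.predict(2), vec![2.0, 2.0]);
    // The misuse is unrepresentable:
    // MeanModel.predict(2); // compile error: no method `predict`
}
```

The proof obligation ("predict requires a fitted model") is carried entirely by the type signatures, which is why the API documentation should state it as a guarantee.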
Recommended formally verified components (stretch goals):
- SGD update step: prove that a single gradient step reduces the objective by at least the theoretically guaranteed amount for a given learning rate and Lipschitz constant
- Convergence criterion: prove that the L-BFGS convergence check `‖grad‖_∞ < tol` is equivalent to the stated stopping condition
- `train_test_split`: prove that the returned index sets are disjoint and their union equals `0..n_samples`
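The `train_test_split` obligation is easy to state as an executable property, which can serve as the property-based test until a machine-checked proof exists (the split logic here is a deliberately trivial stand-in, not the proposed implementation):

```rust
/// Trivial split: first n_test indices to test, the rest to train.
/// A real implementation would shuffle; the property below is what
/// the proof obligation cares about.
fn train_test_split(n_samples: usize, n_test: usize) -> (Vec<usize>, Vec<usize>) {
    let n_test = n_test.min(n_samples);
    let test: Vec<usize> = (0..n_test).collect();
    let train: Vec<usize> = (n_test..n_samples).collect();
    (train, test)
}

/// Property: the two index sets are disjoint and their union is
/// exactly 0..n_samples (every index appears exactly once overall).
fn split_property_holds(train: &[usize], test: &[usize], n_samples: usize) -> bool {
    let mut seen = vec![false; n_samples];
    for &i in train.iter().chain(test) {
        if i >= n_samples || seen[i] {
            return false; // out of range, duplicated, or in both sets
        }
        seen[i] = true;
    }
    seen.iter().all(|&s| s) // union covers every index
}
```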
Tooling: Prusti for Rust-native proofs. For components beyond Prusti's current capabilities, the mathematical core may be specified in Lean 4 as a reference specification, with the Rust implementation validated against the Lean-evaluated reference via fixture tests.
The following table defines what must be in place before an algorithm can ship at each stability level:
| Requirement | Alpha | Beta | Stable 1.0 |
|---|---|---|---|
| Oracle fixtures (standard case) | ✓ | ✓ | ✓ |
| Oracle fixtures (all edge cases) | | ✓ | ✓ |
| Property-based invariant tests (≥ 8) | | ✓ | ✓ |
| Algorithm equivalence document | | ✓ | ✓ |
| Statistical benchmark suite | | | ✓ |
| Fuzz target exists | ✓ | ✓ | ✓ |
| Fuzz corpus run for ≥ 24 CPU-hours | | | ✓ |
| Formal verification (metrics / data structures) | | | ✓ |
| No open P0 correctness bugs | | ✓ | ✓ |
Definition of a P0 correctness bug: any condition where the library produces a result that differs from the algorithm's mathematical specification by more than the allowed ULP budget, panics on valid input, propagates NaN on non-NaN input, hangs, or produces output that fails a required property invariant.
No release may proceed with an open P0 correctness bug under any circumstances.
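Enforcing a ULP budget requires computing ULP distances. The standard trick is to map each `f64` to an integer whose ordering matches float ordering and subtract (a sketch for illustration; finite inputs assumed, NaN and infinity handling omitted):

```rust
/// ULP distance between two finite f64 values: fold the IEEE 754
/// sign-magnitude bit pattern into a lexicographically ordered
/// integer, then take the absolute difference.
fn ulp_distance(a: f64, b: f64) -> u64 {
    fn ordered_bits(x: f64) -> i64 {
        let bits = x.to_bits() as i64;
        // Negative floats have the sign bit set; remap them so that
        // integer order agrees with numeric float order (and both
        // zeros map to the same integer).
        if bits < 0 {
            i64::MIN - bits
        } else {
            bits
        }
    }
    ordered_bits(a).abs_diff(ordered_bits(b))
}
```

With this in place, "max observed 3 ULPs in `coef_`" becomes a mechanically checkable assertion in the benchmark harness rather than a judgment call.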
| Crate | Relationship |
|---|---|
| `linfa` | Closest existing attempt; provides some algorithms but lacks preprocessing, sparse support, cross-validation, and a unified pipeline. ferrolearn should consider linfa for algorithm contributions or offer migration paths. |
| `ndarray` | Primary dense array type; ferrolearn depends on it |
| `faer` | High-performance linear algebra; used for SVD and solvers |
| `sprs` | Existing sparse matrix crate; `ferrolearn-sparse` may wrap or replace it |
| `polars` | DataFrame integration via feature flag |
| `smartcore` | Another incomplete scikit-learn attempt; referenced for API decisions |
| `tract` | ONNX runtime; target for exported models |
| `rayon` | All CPU parallelism |
| `serde` | Serialization of all model types |
This document is a living specification. Open questions and tradeoffs should be resolved in tracked issues before implementation begins.