Commit f34dbcb

Feat/adam w (#237)
* feat[optimiser]: add more optimiser errors
* feat[optimisers]: add adamW
* fix[optimiser]: tests
* fix[optimisers]: fix more tests
* optimiser[adamw]: tests pass
* chore: clippy

Co-authored-by: Marcus Cvjeticanin <[email protected]>

Parent: aa57dad

File tree: 5 files changed, +590 −20 lines

delta/src/classical_ml/losses.rs (+1 −1)

```diff
@@ -108,7 +108,7 @@ where
             .zip(actuals.iter())
             .map(|(p, y)| {
                 let p_clamped = p.max(epsilon).min(T::one() - epsilon);
-                -(y.clone() * p_clamped.ln() + (T::one() - y.clone()) * (T::one() - p_clamped).ln())
+                -(*y * p_clamped.ln() + (T::one() - *y) * (T::one() - p_clamped).ln())
             })
             .sum::<T>()
             / m
```
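This is a small clippy-driven cleanup: replacing `y.clone()` with `*y` only compiles because the numeric type `T` is evidently `Copy`, so the clone was redundant. The expression itself is the per-sample binary cross-entropy term, with the prediction clamped to [ε, 1 − ε] so `ln()` stays finite. A minimal standalone sketch, with `f64` standing in for the crate's generic `T`:

```rust
/// Per-sample binary cross-entropy with clamping, mirroring the diff above
/// (f64 stands in for the crate's generic numeric type T).
fn bce_term(p: f64, y: f64, epsilon: f64) -> f64 {
    // Clamp the prediction away from exactly 0 and 1 so ln() never hits -inf.
    let p_clamped = p.max(epsilon).min(1.0 - epsilon);
    -(y * p_clamped.ln() + (1.0 - y) * (1.0 - p_clamped).ln())
}

fn main() {
    // A confident wrong prediction is penalised heavily...
    println!("{:.4}", bce_term(0.99, 0.0, 1e-12)); // ~4.6052
    // ...while a confident correct one costs almost nothing.
    println!("{:.4}", bce_term(0.99, 1.0, 1e-12)); // ~0.0101
}
```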

delta/src/deep_learning/dataset/vision/cifar10.rs (+1 −1)

```diff
@@ -95,7 +95,7 @@ impl Cifar10Dataset {
             let path = entry.path().unwrap().into_owned();
             let file_name = path.file_name().unwrap().to_string_lossy().to_string();

-            if path.is_dir() || !path.extension().map_or(false, |ext| ext == "bin") {
+            if path.is_dir() || path.extension().is_none_or(|ext| ext != "bin") {
                 continue;
             }
```

delta/src/deep_learning/dataset/vision/cifar100.rs (+1 −1)

```diff
@@ -89,7 +89,7 @@ impl Cifar100Dataset {
             let path = entry.path().unwrap().into_owned();
             let file_name = path.file_name().unwrap().to_string_lossy().to_string();

-            if path.is_dir() || !path.extension().map_or(false, |ext| ext == "bin") {
+            if path.is_dir() || path.extension().is_none_or(|ext| ext != "bin") {
                 continue;
             }
```
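Both dataset loaders get the same clippy fix: `!opt.map_or(false, pred)` is a double negation that `Option::is_none_or` (stabilised in Rust 1.82) expresses directly, skipping an entry when it has no extension or the extension is not "bin". A quick standalone check of the equivalence, using illustrative file names:

```rust
use std::path::Path;

fn main() {
    for name in ["data_batch_1.bin", "readme.txt", "no_extension"] {
        let path = Path::new(name);
        // The old double-negation form...
        let old = !path.extension().map_or(false, |ext| ext == "bin");
        // ...and the new form from the diff: skip unless the extension is "bin".
        let new = path.extension().is_none_or(|ext| ext != "bin");
        assert_eq!(old, new);
        println!("{name}: skip = {new}");
    }
}
```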

delta/src/deep_learning/errors.rs (+15 −0)

```diff
@@ -52,6 +52,16 @@ pub enum OptimizerError {
     IncompatibleGradientWeightShape(Vec<usize>, Vec<usize>),
     /// Error when epsilon is not set or is invalid.
     InvalidEpsilon(String),
+    /// Error when beta parameter is invalid.
+    InvalidBeta(String),
+    /// Error when weight decay parameter is invalid.
+    InvalidWeightDecay(String),
+    /// Error when gradient contains invalid values (NaN or Inf).
+    InvalidGradient(String),
+    /// Error when weight contains invalid values (NaN or Inf).
+    InvalidWeight(String),
+    /// Error when shapes don't match
+    ShapeMismatch(String),
 }

 /// An enumeration of possible errors that can occur in a model.
@@ -105,6 +115,11 @@ impl fmt::Display for OptimizerError {
                 write!(f, "Gradient shape {:?} is incompatible with weight shape {:?}", g, w)
             }
             OptimizerError::InvalidEpsilon(s) => write!(f, "{}", s),
+            OptimizerError::InvalidBeta(s) => write!(f, "{}", s),
+            OptimizerError::InvalidWeightDecay(s) => write!(f, "{}", s),
+            OptimizerError::InvalidGradient(s) => write!(f, "{}", s),
+            OptimizerError::InvalidWeight(s) => write!(f, "{}", s),
+            OptimizerError::ShapeMismatch(s) => write!(f, "Shape mismatch: {}", s),
         }
     }
 }
```
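The fifth changed file, the AdamW optimiser itself, is not shown on this page. AdamW's defining feature is decoupled weight decay (the λ·w term is applied directly to the weights rather than folded into the gradient), which is why weight decay gets its own error variant. Purely as a hypothetical sketch of how the new variants might be used, not the commit's actual code:

```rust
// Hypothetical sketch only -- not the commit's actual AdamW code.
// Path per the diff above, assuming the module is publicly visible.
use delta::deep_learning::errors::OptimizerError;

fn validate_adamw_step(
    beta1: f64,
    beta2: f64,
    weight_decay: f64,
    gradient: &[f64],
    weights: &[f64],
) -> Result<(), OptimizerError> {
    // AdamW's exponential-decay rates must lie in [0, 1).
    if !(0.0..1.0).contains(&beta1) || !(0.0..1.0).contains(&beta2) {
        return Err(OptimizerError::InvalidBeta("betas must be in [0, 1)".to_string()));
    }
    // Decoupled weight decay is a non-negative coefficient.
    if weight_decay < 0.0 {
        return Err(OptimizerError::InvalidWeightDecay(
            "weight decay must be non-negative".to_string(),
        ));
    }
    // Reject NaN/Inf before they poison the moment estimates.
    if gradient.iter().any(|g| !g.is_finite()) {
        return Err(OptimizerError::InvalidGradient("gradient contains NaN or Inf".to_string()));
    }
    if weights.iter().any(|w| !w.is_finite()) {
        return Err(OptimizerError::InvalidWeight("weights contain NaN or Inf".to_string()));
    }
    // Gradient and weights must agree element-for-element.
    if gradient.len() != weights.len() {
        return Err(OptimizerError::ShapeMismatch(format!(
            "gradient has {} elements, weights have {}",
            gradient.len(),
            weights.len()
        )));
    }
    Ok(())
}
```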
