Skip to content

Commit 648a938

Browse files
committed
Model tweaks.
1 parent bc3658c commit 648a938

3 files changed

Lines changed: 8 additions & 19 deletions

File tree

kord/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ ml_train = [
6363
"burn-ndarray/std",
6464
]
6565
ml_infer = ["ml_base", "burn", "burn-ndarray", "burn-ndarray/std"]
66-
ml_gpu = ["ml_train", "burn-tch", "burn-wgpu"]
66+
ml_gpu = ["ml_train", "burn-tch"]
67+
ml_wgpu = ["ml_train", "burn-wgpu"]
6768

6869
wasm = [
6970
"rodio/wasm-bindgen",

kord/src/bin.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ fn start(args: Args) -> Void {
453453

454454
klib::ml::train::run_training::<Autodiff<LibTorch<f32>>>(device, &config, true, true)?;
455455
}
456-
#[cfg(feature = "ml_gpu")]
456+
#[cfg(feature = "ml_wgpu")]
457457
"wgpu" => {
458458
use burn_wgpu::{Wgpu, WgpuDevice};
459459

kord/src/ml/base/model.rs

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -68,24 +68,12 @@ impl<B: Backend> KordModel<B> {
6868
let output = self.forward(item.samples);
6969

7070
let loss = BinaryCrossEntropyLossConfig::new().with_logits(false).init(&output.device());
71-
let loss = loss.forward(output.clone(), targets.clone());
71+
let mut loss = loss.forward(output.clone(), targets.clone());
7272

73-
// let loss = MeanSquareLoss::default();
74-
// let loss = loss.forward(output.clone(), targets.clone());
75-
76-
// let loss = BinaryCrossEntropyLoss::default();
77-
// let loss = loss.forward(output.clone(), targets.clone());
78-
79-
// let mut loss = FocalLoss::default();
80-
// loss.gamma = 2.0;
81-
// let loss = loss.forward(output.clone(), targets.clone());
82-
83-
//let loss = loss + l1_regularization(self, 1e-4);
84-
85-
// let harmonic_penalty_tensor = get_harmonic_penalty_tensor().to_device(&output.device());
86-
// let harmonic_loss = output.clone().matmul(harmonic_penalty_tensor).sum_dim(0).mean().mul_scalar(0.0001);
87-
88-
// let loss = loss + harmonic_loss;
73+
// Add L1 regularization
74+
// let l1_reg_strength = 1e-4;
75+
// let l1_penalty = self.output.weight.val().abs().sum() * l1_reg_strength;
76+
// loss = loss + l1_penalty;
8977

9078
MultiLabelClassificationOutput { loss, output, targets }
9179
}

0 commit comments

Comments
 (0)