training.rs

use crate::dataset::{HousingBatcher, HousingDataset};
use crate::model::RegressionModelConfig;
use burn::optim::AdamConfig;
use burn::{
    data::{dataloader::DataLoaderBuilder, dataset::Dataset},
    prelude::*,
    record::{CompactRecorder, NoStdTrainingRecorder},
    tensor::backend::AutodiffBackend,
    train::{metric::LossMetric, LearnerBuilder},
};
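
// Training hyperparameters. `#[derive(Config)]` generates `ExpConfig::new(optimizer)`,
// a constructor taking the one field without a `#[config(default)]`, plus `with_*`
// setters for overriding the defaulted fields.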
#[derive(Config)]
pub struct ExpConfig {
    #[config(default = 100)]
    pub num_epochs: usize,

    #[config(default = 2)]
    pub num_workers: usize,

    #[config(default = 1337)]
    pub seed: u64,

    pub optimizer: AdamConfig,

    #[config(default = 256)]
    pub batch_size: usize,
}
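
// For example, a config overriding one default via the generated setters
// (a sketch of the derived API, assuming the derive behaves as described above):
//
//     let config = ExpConfig::new(AdamConfig::new()).with_batch_size(128);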

fn create_artifact_dir(artifact_dir: &str) {
    // Remove any existing artifacts beforehand so the learner summary is accurate
    std::fs::remove_dir_all(artifact_dir).ok();
    std::fs::create_dir_all(artifact_dir).ok();
}

pub fn run<B: AutodiffBackend>(artifact_dir: &str, device: B::Device) {
    create_artifact_dir(artifact_dir);

    // Config
    let optimizer = AdamConfig::new();
    let config = ExpConfig::new(optimizer);
    let model = RegressionModelConfig::new().init(&device);
    B::seed(config.seed);

    // Define train/valid datasets and dataloaders
    let train_dataset = HousingDataset::train();
    let valid_dataset = HousingDataset::validation();

    println!("Train Dataset Size: {}", train_dataset.len());
    println!("Valid Dataset Size: {}", valid_dataset.len());
    let batcher_train = HousingBatcher::<B>::new(device.clone());
    let batcher_test = HousingBatcher::<B::InnerBackend>::new(device.clone());

    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(train_dataset);

    let dataloader_test = DataLoaderBuilder::new(batcher_test)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(valid_dataset);

    // Learner: records train/valid loss metrics, checkpoints with the compact
    // recorder, and reports a summary after training
    let learner = LearnerBuilder::new(artifact_dir)
        .metric_train_numeric(LossMetric::new())
        .metric_valid_numeric(LossMetric::new())
        .with_file_checkpointer(CompactRecorder::new())
        .devices(vec![device.clone()])
        .num_epochs(config.num_epochs)
        .summary()
        .build(model, config.optimizer.init(), 1e-3); // fixed learning rate of 1e-3

    let model_trained = learner.fit(dataloader_train, dataloader_test);

    config
        .save(format!("{artifact_dir}/config.json").as_str())
        .unwrap();

    model_trained
        .save_file(
            format!("{artifact_dir}/model"),
            &NoStdTrainingRecorder::new(),
        )
        .expect("Failed to save trained model");
}
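
// A minimal sketch of how `run` might be invoked from a binary entry point,
// assuming Burn's `ndarray` backend feature is enabled; the backend type and
// artifact path are placeholders to adapt to your setup.
//
//     use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
//
//     fn main() {
//         training::run::<Autodiff<NdArray>>("/tmp/housing-regression", NdArrayDevice::Cpu);
//     }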