|
| 1 | +use std::{io::BufReader, fs::File}; |
| 2 | + |
| 3 | +use juice::{layers::*, capnp_util}; |
| 4 | +use juice::layer::{LayerType, LayerConfig}; |
| 5 | +use juice::capnp_util::{CapnpRead, CapnpWrite}; |
| 6 | + |
| 7 | +fn main() { |
| 8 | + let mut cfg = SequentialConfig::default(); |
| 9 | + cfg.add_input("data", &[64, 3, 224, 224]); |
| 10 | + |
| 11 | + let conv1_layer_cfg = ConvolutionConfig { |
| 12 | + num_output: 64, |
| 13 | + filter_shape: vec![3], |
| 14 | + padding: vec![1], |
| 15 | + stride: vec![1], |
| 16 | + }; |
| 17 | + cfg.add_layer(LayerConfig::new("conv1", conv1_layer_cfg)); |
| 18 | + cfg.add_layer(LayerConfig::new("conv1/relu", LayerType::ReLU)); |
| 19 | + cfg.add_layer(LayerConfig::new( |
| 20 | + "pool1", |
| 21 | + PoolingConfig { |
| 22 | + mode: PoolingMode::Max, |
| 23 | + filter_shape: vec![2], |
| 24 | + stride: vec![2], |
| 25 | + padding: vec![0], |
| 26 | + }, |
| 27 | + )); |
| 28 | + |
| 29 | + cfg.add_layer(LayerConfig::new( |
| 30 | + "conv2", |
| 31 | + ConvolutionConfig { |
| 32 | + num_output: 128, |
| 33 | + filter_shape: vec![3], |
| 34 | + padding: vec![1], |
| 35 | + stride: vec![1], |
| 36 | + }, |
| 37 | + )); |
| 38 | + cfg.add_layer(LayerConfig::new("conv2/relu", LayerType::ReLU)); |
| 39 | + cfg.add_layer(LayerConfig::new("fc1", LinearConfig { output_size: 2000 })); |
| 40 | + cfg.add_layer(LayerConfig::new("fc2", LinearConfig { output_size: 100 })); |
| 41 | + cfg.add_layer(LayerConfig::new("fc3", LinearConfig { output_size: 2 })); |
| 42 | + |
| 43 | + |
| 44 | + let p = "./foo.serialized.capnp"; |
| 45 | + { |
| 46 | + let mut f = File::options().truncate(true).create(true).write(true).open(p).unwrap(); |
| 47 | + // let mut builder = juice::juice_capnp::sequential_config::Builder; |
| 48 | + let mut builder = capnp::message::TypedBuilder::<juice::juice_capnp::sequential_config::Owned>::new_default(); |
| 49 | + let facade = &mut builder.get_root().unwrap(); |
| 50 | + cfg.write_capnp(facade); |
| 51 | + |
| 52 | + capnp::serialize::write_message(&mut f, builder.borrow_inner()).unwrap(); |
| 53 | + } |
| 54 | + let reincarnation = { |
| 55 | + let f = File::options().read(true).open(p).unwrap(); |
| 56 | + |
| 57 | + let reader = BufReader::new(f); |
| 58 | + let reader = capnp::serialize::try_read_message( |
| 59 | + reader, |
| 60 | + capnp::message::ReaderOptions { |
| 61 | + traversal_limit_in_words: None, |
| 62 | + nesting_limit: 100, |
| 63 | + }).unwrap().unwrap(); |
| 64 | + <SequentialConfig as CapnpRead>::read_capnp(reader.get_root().unwrap()) |
| 65 | + }; |
| 66 | + |
| 67 | + assert_eq!(dbg!(cfg), dbg!(reincarnation)); |
| 68 | + |
| 69 | +} |
0 commit comments