44
55# iTransformer: Rust Implementation
66
7- An **iTransformer** implementation in **Rust**, inspired by the [lucidrains iTransformer repository](https://github.com/lucidrains/iTransformer), and based on the original research and implementation from [Tsinghua University's iTransformer repository](https://github.com/thuml/iTransformer).
7+ An **iTransformer** implementation in **Rust** using [Hugging Face Candle](https://github.com/huggingface/candle), inspired by the [lucidrains iTransformer repository](https://github.com/lucidrains/iTransformer), and based on the original research and implementation from [Tsinghua University's iTransformer repository](https://github.com/thuml/iTransformer).
88
99## **What is iTransformer?**
1010
@@ -16,7 +16,8 @@ iTransformer introduces an **inverted Transformer architecture** designed for **
1616- **Flexible Prediction Lengths:** Supports predictions at multiple horizons (e.g., 12, 24, 36, 48 steps ahead).
1717- **Scalability:** Handles hundreds of variates with efficiency.
1818- **Zero-Shot Generalization:** Train on partial variates and generalize to unseen variates.
19- - **Efficient Attention Mechanisms:** Compatible with advanced techniques like FlashAttention.
19+ - **Multiple Model Variants:** ITransformer, ITransformer2D, and ITransformerFFT.
20+ - **Hardware Acceleration:** Metal (macOS) and CUDA support via Candle.
2021
2122### 🛠️ **Architecture Overview**
2223
@@ -41,42 +42,154 @@ To get started, ensure you have Rust and Cargo installed. Then:
4142
4243``` bash
4344# Add iTransformer to your project dependencies
44- cargo add itransformer-rs
45+ cargo add itransformer
4546```
4647
48+ For GPU acceleration:
49+ ``` bash
50+ # macOS Metal support
51+ cargo add itransformer --features metal
52+
53+ # CUDA support
54+ cargo add itransformer --features cuda
55+ ```
56+
57+ ## **Model Variants**
58+
59+ This library provides three model variants:
60+
61+ | Model | Description |
62+ | -------| -------------|
63+ | ` ITransformer ` | Base inverted transformer for multivariate time series forecasting |
64+ | ` ITransformer2D ` | Extended variant with granular time attention via ` num_time_tokens ` parameter |
65+ | ` ITransformerFFT ` | Variant with additional Fourier tokens prepended to the attention sequence |
66+
4767## **Usage**
4868
49- ``` rust
50- use tch :: {Device , Tensor , nn :: VarStore , Kind };
69+ ### ITransformer (Base Model)
5170
52- fn main () -> Result <(), Box <dyn std :: error :: Error >> {
53- let vs = VarStore :: new (Device :: Cpu );
71+ ``` rust
72+ use candle_core :: {Device , DType , Tensor };
73+ use candle_nn :: {VarBuilder , VarMap };
74+ use itransformer :: ITransformer ;
75+
76+ fn main () -> candle_core :: Result <()> {
77+ let device = Device :: Cpu ;
78+ let varmap = VarMap :: new ();
79+ let vb = VarBuilder :: from_varmap (& varmap , DType :: F32 , & device );
80+
5481 let model = ITransformer :: new (
55- & ( vs . root () / " itransformer " ) ,
56- 137 , // num_variates
57- 96 , // lookback_len
58- 6 , // depth
59- 256 , // dim
60- Some (1 ), // num_tokens_per_variate
61- vec! [12 , 24 , 36 , 48 ], // pred_length
62- Some (64 ), // dim_head
63- Some (8 ), // heads
64- None , // attn_drop_p
65- None , // ff_mult
66- None , // ff_drop_p
67- None , // num_mem_tokens
68- Some (true ), // use_reversible_instance_norm
69- None , // reversible_instance_norm_affine
70- false , // flash_attn
71- & Device :: Cpu ,
82+ vb ,
83+ 137 , // num_variates
84+ 96 , // lookback_len
85+ 6 , // depth
86+ 256 , // dim
87+ Some (1 ), // num_tokens_per_variate
88+ vec! [12 , 24 , 36 , 48 ], // pred_length
89+ Some (64 ), // dim_head
90+ Some (8 ), // heads
91+ None , // attn_drop_p
92+ None , // ff_mult
93+ None , // ff_drop_p
94+ None , // num_mem_tokens
95+ Some (true ), // use_reversible_instance_norm
96+ None , // reversible_instance_norm_affine
97+ false , // flash_attn
98+ & device ,
7299 )? ;
73- let time_series = Tensor :: randn ([2 , 96 , 137 ], (Kind :: Float , Device :: Cpu ));
74- let preds = model . forward (& time_series , None , false );
100+
101+ let time_series = Tensor :: randn (0f32 , 1f32 , (2 , 96 , 137 ), & device )? ;
102+ let preds = model . forward (& time_series , None , false )? ;
75103 println! (" {:?}" , preds );
76104 Ok (())
77105}
78106```
79107
108+ ### ITransformer2D (Time Token Variant)
109+
110+ ``` rust
111+ use candle_core :: {Device , DType , Tensor };
112+ use candle_nn :: {VarBuilder , VarMap };
113+ use itransformer :: ITransformer2D ;
114+
115+ fn main () -> candle_core :: Result <()> {
116+ let device = Device :: Cpu ;
117+ let varmap = VarMap :: new ();
118+ let vb = VarBuilder :: from_varmap (& varmap , DType :: F32 , & device );
119+
120+ let model = ITransformer2D :: new (
121+ vb ,
122+ 137 , // num_variates
123+ 96 , // lookback_len
124+ 4 , // depth
125+ 256 , // dim
126+ 8 , // num_time_tokens (lookback_len must be divisible by this)
127+ vec! [12 , 24 ], // pred_length
128+ Some (32 ), // dim_head
129+ Some (4 ), // heads
130+ None , None , None , None , None , None ,
131+ false , // flash_attn
132+ & device ,
133+ )? ;
134+
135+ let time_series = Tensor :: randn (0f32 , 1f32 , (2 , 96 , 137 ), & device )? ;
136+ let preds = model . forward (& time_series , None , false )? ;
137+ println! (" {:?}" , preds );
138+ Ok (())
139+ }
140+ ```
141+
142+ ### ITransformerFFT (Fourier Token Variant)
143+
144+ ``` rust
145+ use candle_core :: {Device , DType , Tensor };
146+ use candle_nn :: {VarBuilder , VarMap };
147+ use itransformer :: ITransformerFFT ;
148+
149+ fn main () -> candle_core :: Result <()> {
150+ let device = Device :: Cpu ;
151+ let varmap = VarMap :: new ();
152+ let vb = VarBuilder :: from_varmap (& varmap , DType :: F32 , & device );
153+
154+ let model = ITransformerFFT :: new (
155+ vb ,
156+ 137 , // num_variates
157+ 96 , // lookback_len
158+ 4 , // depth
159+ 256 , // dim
160+ Some (1 ), // num_tokens_per_variate
161+ 4 , // num_fft_tokens
162+ vec! [12 , 24 ], // pred_length
163+ Some (32 ), // dim_head
164+ Some (4 ), // heads
165+ None , None , None , None , None , None ,
166+ false , // flash_attn
167+ & device ,
168+ )? ;
169+
170+ let time_series = Tensor :: randn (0f32 , 1f32 , (2 , 96 , 137 ), & device )? ;
171+ let preds = model . forward (& time_series , None , false )? ;
172+ println! (" {:?}" , preds );
173+ Ok (())
174+ }
175+ ```
176+
177+ ## **Project Structure**
178+
179+ ```
180+ src/
181+ ├── attend.rs            # Attention computation
182+ ├── attention.rs         # ToQKV, ToValueResidualMix, ToVGates, ToOut, Attention
183+ ├── feedforward.rs       # GEGLU, FeedForward
184+ ├── itransformer.rs      # Base ITransformer
185+ ├── itransformer2d.rs    # ITransformer2D with time tokens
186+ ├── itransformer_fft.rs  # ITransformerFFT with Fourier tokens
187+ ├── lib.rs               # Module declarations and re-exports
188+ ├── mlp_in.rs            # Input projection layer
189+ ├── pred_head.rs         # Prediction heads
190+ └── revin.rs             # Reversible Instance Normalization
191+ ```
192+
80193
81194## **References**
82195
0 commit comments