Skip to content

Commit 419a0d8

Browse files
committed
feat: update readme
1 parent 26d0b11 commit 419a0d8

1 file changed

Lines changed: 139 additions & 26 deletions

File tree

β€ŽREADME.mdβ€Ž

Lines changed: 139 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
# iTransformer: Rust Implementation
66

7-
An **iTransformer** implementation in **Rust**, inspired by the [lucidrains iTransformer repository](https://github.com/lucidrains/iTransformer), and based on the original research and implementation from [Tsinghua University's iTransformer repository](https://github.com/thuml/iTransformer).
7+
An **iTransformer** implementation in **Rust** using [Hugging Face Candle](https://github.com/huggingface/candle), inspired by the [lucidrains iTransformer repository](https://github.com/lucidrains/iTransformer), and based on the original research and implementation from [Tsinghua University's iTransformer repository](https://github.com/thuml/iTransformer).
88

99
## πŸ“š **What is iTransformer?**
1010

@@ -16,7 +16,8 @@ iTransformer introduces an **inverted Transformer architecture** designed for **
1616
- **Flexible Prediction Lengths:** Supports predictions at multiple horizons (e.g., 12, 24, 36, 48 steps ahead).
1717
- **Scalability:** Handles hundreds of variates with efficiency.
1818
- **Zero-Shot Generalization:** Train on partial variates and generalize to unseen variates.
19-
- **Efficient Attention Mechanisms:** Compatible with advanced techniques like FlashAttention.
19+
- **Multiple Model Variants:** ITransformer, ITransformer2D, and ITransformerFFT.
20+
- **Hardware Acceleration:** Metal (macOS) and CUDA support via Candle.
2021

2122
### πŸ› οΈ **Architecture Overview**
2223

@@ -41,42 +42,154 @@ To get started, ensure you have Rust and Cargo installed. Then:
4142

4243
```bash
4344
# Add iTransformer to your project dependencies
44-
cargo add itransformer-rs
45+
cargo add itransformer
4546
```
4647

48+
For GPU acceleration:
49+
```bash
50+
# macOS Metal support
51+
cargo add itransformer --features metal
52+
53+
# CUDA support
54+
cargo add itransformer --features cuda
55+
```
56+
57+
## πŸ—οΈ **Model Variants**
58+
59+
This library provides three model variants:
60+
61+
| Model | Description |
62+
|-------|-------------|
63+
| `ITransformer` | Base inverted transformer for multivariate time series forecasting |
64+
| `ITransformer2D` | Extended variant with granular time attention via `num_time_tokens` parameter |
65+
| `ITransformerFFT` | Variant with additional Fourier tokens prepended to the attention sequence |
66+
4767
## πŸ“ **Usage**
4868

49-
```rust
50-
use tch::{Device, Tensor, nn::VarStore, Kind};
69+
### ITransformer (Base Model)
5170

52-
fn main() -> Result<(), Box<dyn std::error::Error>> {
53-
let vs = VarStore::new(Device::Cpu);
71+
```rust
72+
use candle_core::{Device, DType, Tensor};
73+
use candle_nn::{VarBuilder, VarMap};
74+
use itransformer::ITransformer;
75+
76+
fn main() -> candle_core::Result<()> {
77+
let device = Device::Cpu;
78+
let varmap = VarMap::new();
79+
let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
80+
5481
let model = ITransformer::new(
55-
&(vs.root() / "itransformer"),
56-
137, // num_variates
57-
96, // lookback_len
58-
6, // depth
59-
256, // dim
60-
Some(1), // num_tokens_per_variate
61-
vec![12, 24, 36, 48], // pred_length
62-
Some(64), // dim_head
63-
Some(8), // heads
64-
None, // attn_drop_p
65-
None, // ff_mult
66-
None, // ff_drop_p
67-
None, // num_mem_tokens
68-
Some(true), // use_reversible_instance_norm
69-
None, // reversible_instance_norm_affine
70-
false, // flash_attn
71-
&Device::Cpu,
82+
vb,
83+
137, // num_variates
84+
96, // lookback_len
85+
6, // depth
86+
256, // dim
87+
Some(1), // num_tokens_per_variate
88+
vec![12, 24, 36, 48], // pred_length
89+
Some(64), // dim_head
90+
Some(8), // heads
91+
None, // attn_drop_p
92+
None, // ff_mult
93+
None, // ff_drop_p
94+
None, // num_mem_tokens
95+
Some(true), // use_reversible_instance_norm
96+
None, // reversible_instance_norm_affine
97+
false, // flash_attn
98+
&device,
7299
)?;
73-
let time_series = Tensor::randn([2, 96, 137], (Kind::Float, Device::Cpu));
74-
let preds = model.forward(&time_series, None, false);
100+
101+
let time_series = Tensor::randn(0f32, 1f32, (2, 96, 137), &device)?;
102+
let preds = model.forward(&time_series, None, false)?;
75103
println!("{:?}", preds);
76104
Ok(())
77105
}
78106
```
79107

108+
### ITransformer2D (Time Token Variant)
109+
110+
```rust
111+
use candle_core::{Device, DType, Tensor};
112+
use candle_nn::{VarBuilder, VarMap};
113+
use itransformer::ITransformer2D;
114+
115+
fn main() -> candle_core::Result<()> {
116+
let device = Device::Cpu;
117+
let varmap = VarMap::new();
118+
let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
119+
120+
let model = ITransformer2D::new(
121+
vb,
122+
137, // num_variates
123+
96, // lookback_len
124+
4, // depth
125+
256, // dim
126+
8, // num_time_tokens (lookback_len must be divisible by this)
127+
vec![12, 24], // pred_length
128+
Some(32), // dim_head
129+
Some(4), // heads
130+
None, None, None, None, None, None,
131+
false, // flash_attn
132+
&device,
133+
)?;
134+
135+
let time_series = Tensor::randn(0f32, 1f32, (2, 96, 137), &device)?;
136+
let preds = model.forward(&time_series, None, false)?;
137+
println!("{:?}", preds);
138+
Ok(())
139+
}
140+
```
141+
142+
### ITransformerFFT (Fourier Token Variant)
143+
144+
```rust
145+
use candle_core::{Device, DType, Tensor};
146+
use candle_nn::{VarBuilder, VarMap};
147+
use itransformer::ITransformerFFT;
148+
149+
fn main() -> candle_core::Result<()> {
150+
let device = Device::Cpu;
151+
let varmap = VarMap::new();
152+
let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
153+
154+
let model = ITransformerFFT::new(
155+
vb,
156+
137, // num_variates
157+
96, // lookback_len
158+
4, // depth
159+
256, // dim
160+
Some(1), // num_tokens_per_variate
161+
4, // num_fft_tokens
162+
vec![12, 24], // pred_length
163+
Some(32), // dim_head
164+
Some(4), // heads
165+
None, None, None, None, None, None,
166+
false, // flash_attn
167+
&device,
168+
)?;
169+
170+
let time_series = Tensor::randn(0f32, 1f32, (2, 96, 137), &device)?;
171+
let preds = model.forward(&time_series, None, false)?;
172+
println!("{:?}", preds);
173+
Ok(())
174+
}
175+
```
176+
177+
## πŸ“ **Project Structure**
178+
179+
```
180+
src/
181+
β”œβ”€β”€ attend.rs # Attention computation
182+
β”œβ”€β”€ attention.rs # ToQKV, ToValueResidualMix, ToVGates, ToOut, Attention
183+
β”œβ”€β”€ feedforward.rs # GEGLU, FeedForward
184+
β”œβ”€β”€ itransformer.rs # Base ITransformer
185+
β”œβ”€β”€ itransformer2d.rs # ITransformer2D with time tokens
186+
β”œβ”€β”€ itransformer_fft.rs # ITransformerFFT with Fourier tokens
187+
β”œβ”€β”€ lib.rs # Module declarations and re-exports
188+
β”œβ”€β”€ mlp_in.rs # Input projection layer
189+
β”œβ”€β”€ pred_head.rs # Prediction heads
190+
β”œβ”€β”€ revin.rs # Reversible Instance Normalization
191+
```
192+
80193

81194
## πŸ“– **References**
82195

0 commit comments

Comments
Β (0)