This repository implements the Spatial-Temporal Continuous Dynamics Network (STCDN), a neural network architecture for traffic forecasting. It combines Graph Attention Networks (GAT) with Neural Ordinary Differential Equations (Neural ODEs) to model complex spatial-temporal dependencies in traffic data.
Traditional traffic forecasting models often struggle to capture complex spatial-temporal dependencies. STCDN addresses this by:
- Using Neural ODEs to model continuous temporal dynamics
- Employing Graph Attention Networks to capture spatial dependencies
- Learning adaptive graph structures from data
- Supporting multiple traffic datasets (PEMS series)
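To illustrate what "continuous temporal dynamics" means, the sketch below treats the hidden state of every sensor as the solution of an ODE, dh/dt = f(h), and integrates it with a hand-written fixed-step Runge-Kutta (RK4) solver. The vector field `ode_func` (neighbour mixing through a row-normalised adjacency, then `tanh`, minus a decay term) is a hypothetical choice for illustration, not STCDN's actual ODE function. The repository lists torchdiffeq as a dependency, so the real code presumably calls its `odeint`, which takes a function of `(t, y)` and offers adaptive solvers (default `dopri5`) as well as fixed-step ones such as `rk4`.

```python
import numpy as np


def ode_func(h, adj, weight):
    """Hypothetical graph-coupled vector field: dh/dt = tanh(A @ h @ W) - h."""
    return np.tanh(adj @ h @ weight) - h


def rk4_integrate(func, h0, t0, t1, steps, *args):
    """Fixed-step 4th-order Runge-Kutta from t0 to t1; returns the state at t1."""
    dt = (t1 - t0) / steps
    h = h0
    for _ in range(steps):
        k1 = func(h, *args)
        k2 = func(h + 0.5 * dt * k1, *args)
        k3 = func(h + 0.5 * dt * k2, *args)
        k4 = func(h + dt * k3, *args)
        h = h + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return h


rng = np.random.default_rng(0)
num_nodes, hidden = 5, 8
adj = rng.random((num_nodes, num_nodes))
adj = adj / adj.sum(axis=1, keepdims=True)       # row-normalise: each row sums to 1
weight = rng.normal(scale=0.1, size=(hidden, hidden))
h0 = rng.normal(size=(num_nodes, hidden))        # initial hidden state per sensor

# Evolve the hidden state continuously from t=0 to t=1 using 20 RK4 steps.
h1 = rk4_integrate(ode_func, h0, 0.0, 1.0, 20, adj, weight)
print(h1.shape)  # (5, 8)
```

Because the state is defined at every time t, predictions can in principle be read out at arbitrary (including irregular) time points rather than only at fixed discrete steps.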
- Continuous Time Modeling: Neural ODEs model temporal evolution as continuous processes
- Graph Neural Networks: GAT layers capture complex spatial relationships between sensors
- Adaptive Graph Learning: Learns connections between nodes from data rather than relying solely on predefined graphs
- Multi-Dataset Support: Works with various PEMS traffic datasets
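The GAT and adaptive-graph features can be sketched together in plain NumPy. Two assumptions apply here: the learned adjacency uses softmax(ReLU(E1 E2^T)) over two node-embedding matrices, a common formulation popularised by Graph WaveNet and not necessarily what graph_learner.py does; and the attention is a single-head, dense-matrix GAT layer, whereas gat.py presumably builds on DGL. The embeddings are random here; in a real model they would be trained by backpropagation.

```python
import numpy as np


def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)


def adaptive_adjacency(e1, e2):
    """Row-wise softmax(ReLU(E1 @ E2.T)): a dense, directed, learned adjacency."""
    scores = np.maximum(e1 @ e2.T, 0.0)
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(scores)
    return exp / exp.sum(axis=1, keepdims=True)


def gat_layer(h, w, a, mask):
    """Single-head GAT: alpha_ij = softmax_j(LeakyReLU(a^T [W h_i || W h_j])), j restricted by mask[i]."""
    z = h @ w                                    # projected features, (N, F_out)
    f_out = z.shape[1]
    src = z @ a[:f_out]                          # a_1 . W h_i, shape (N,)
    dst = z @ a[f_out:]                          # a_2 . W h_j, shape (N,)
    e = leaky_relu(src[:, None] + dst[None, :])  # raw scores e_ij, (N, N)
    e = np.where(mask, e, -np.inf)               # drop non-neighbours
    e = e - e.max(axis=1, keepdims=True)
    alpha = np.exp(e)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)
    return alpha @ z, alpha                      # out_i = sum_j alpha_ij W h_j


rng = np.random.default_rng(0)
n, f_in, f_out, emb_dim, k = 6, 4, 3, 2, 2
h = rng.normal(size=(n, f_in))                   # node (sensor) features
predefined = rng.random((n, n)) > 0.7            # hypothetical predefined sensor graph
e1 = rng.normal(size=(n, emb_dim))               # source-node embeddings
e2 = rng.normal(size=(n, emb_dim))               # target-node embeddings
learned = adaptive_adjacency(e1, e2)

# Keep each node's k strongest learned links and merge them with the predefined graph.
top = np.argsort(-learned, axis=1)[:, :k]
learned_mask = np.zeros((n, n), dtype=bool)
np.put_along_axis(learned_mask, top, True, axis=1)
mask = predefined | learned_mask
np.fill_diagonal(mask, True)                     # self-loops: every row has a neighbour

out, alpha = gat_layer(h, rng.normal(size=(f_in, f_out)), rng.normal(size=2 * f_out), mask)
print(out.shape)  # (6, 3)
```

The self-loops matter: a node with no allowed neighbours would otherwise produce a row of `-inf` scores and a NaN softmax.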
The model follows an encoder-decoder architecture:
- Encoder: Processes historical traffic data using GAT layers within an ODE framework
- Hidden Transformation: Transforms encoder output to decoder input
- Decoder: Generates future traffic predictions using another ODE-GAT module
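One plausible way to wire the three stages together is sketched below. It is not the repository's model: the encoder uses an ODE-RNN-style update (evolve the state with the ODE between observations, then absorb each new observation), the hidden transformation is a single `tanh` linear map, and the decoder integrates a second ODE and reads out one prediction per future step. To stay short, it uses explicit Euler steps and a fixed row-normalised adjacency in place of the GAT layer; the class name `EncoderDecoderSketch` and all weights are hypothetical.

```python
import numpy as np


def euler_solve(func, h, t0, t1, steps=10):
    """Fixed-step explicit Euler integration of dh/dt = func(h) from t0 to t1."""
    dt = (t1 - t0) / steps
    for _ in range(steps):
        h = h + dt * func(h)
    return h


class EncoderDecoderSketch:
    """Toy encoder -> hidden transformation -> decoder pipeline on a fixed graph."""

    def __init__(self, adj, in_dim, hidden, out_dim, seed=0):
        rng = np.random.default_rng(seed)
        deg = adj.sum(axis=1, keepdims=True)
        self.adj = adj / np.maximum(deg, 1e-8)                     # row-normalised graph operator
        self.w_in = rng.normal(scale=0.3, size=(in_dim, hidden))
        self.w_enc = rng.normal(scale=0.3, size=(hidden, hidden))
        self.w_mid = rng.normal(scale=0.3, size=(hidden, hidden))  # hidden transformation
        self.w_dec = rng.normal(scale=0.3, size=(hidden, hidden))
        self.w_out = rng.normal(scale=0.3, size=(hidden, out_dim))

    def _dynamics(self, w):
        # Graph-coupled vector field: each node's state mixes with its neighbours'.
        return lambda h: np.tanh(self.adj @ h @ w) - h

    def forward(self, x, horizon):
        """x: (T_in, N, in_dim) history -> (horizon, N, out_dim) forecast."""
        h = np.zeros((x.shape[1], self.w_in.shape[1]))
        enc_f = self._dynamics(self.w_enc)
        for x_t in x:                                   # iterate over input time steps
            h = euler_solve(enc_f, h, 0.0, 1.0)         # evolve between observations
            h = np.tanh(h + x_t @ self.w_in)            # absorb the new observation
        h = np.tanh(h @ self.w_mid)                     # encoder state -> decoder state
        dec_f = self._dynamics(self.w_dec)
        preds = []
        for _ in range(horizon):
            h = euler_solve(dec_f, h, 0.0, 1.0)         # advance to the next target time
            preds.append(h @ self.w_out)                # linear read-out per node
        return np.stack(preds)


adj = np.ones((4, 4))                                   # fully connected 4-node toy graph
model = EncoderDecoderSketch(adj, in_dim=1, hidden=8, out_dim=1)
history = np.random.default_rng(1).normal(size=(12, 4, 1))  # 12 past steps
forecast = model.forward(history, horizon=12)
print(forecast.shape)  # (12, 4, 1)
```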
Requirements:
- Python 3.7+
- PyTorch 1.8+
- DGL (Deep Graph Library)
- torchdiffeq
- NumPy
- SciPy
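Before training, a quick check that the interpreter version is sufficient and that the listed packages are importable can save a failed run. This is a hypothetical helper, not part of the repository; it uses `importlib.util.find_spec`, so it only detects whether each package is installed and does not verify versions such as PyTorch 1.8+.

```python
import importlib.util
import sys

# Import names for the listed requirements.
REQUIRED_MODULES = ["torch", "dgl", "torchdiffeq", "numpy", "scipy"]


def check_environment(modules=REQUIRED_MODULES, min_python=(3, 7)):
    """Return (python_ok, {module: installed?}) without importing the packages."""
    python_ok = sys.version_info[:2] >= min_python
    found = {name: importlib.util.find_spec(name) is not None for name in modules}
    return python_ok, found


python_ok, found = check_environment()
print("Python >= 3.7:", python_ok)
for name, ok in found.items():
    print(f"{name:12s} {'ok' if ok else 'MISSING'}")
```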
Install dependencies with:
pip install -r requirements.txt
To train the model on a PEMS dataset:
python main.py --data PEMS-D3 --device cuda:0 --lr 0.003 --epochs 500
Key arguments:
- --data: Dataset to use (PEMS-D3, PEMS-D4, PEMS-D7, PEMS-D8, etc.)
- --device: Device to use (cuda:0, cpu, mps)
- --lr: Learning rate
- --epochs: Number of training epochs
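For readers extending the training script, the following is a minimal `argparse` sketch of how the documented flags could be declared. The function name `build_parser` is hypothetical, and the defaults are simply copied from the example command; the actual defaults and any extra flags in main.py may differ.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags documented for main.py."""
    parser = argparse.ArgumentParser(description="Train STCDN on a PEMS dataset (sketch)")
    parser.add_argument("--data", type=str, default="PEMS-D3",
                        help="dataset name, e.g. PEMS-D3, PEMS-D4, PEMS-D7, PEMS-D8")
    parser.add_argument("--device", type=str, default="cuda:0",
                        help="torch device string: cuda:0, cpu or mps")
    parser.add_argument("--lr", type=float, default=0.003, help="learning rate")
    parser.add_argument("--epochs", type=int, default=500, help="number of training epochs")
    return parser


# Parse an explicit argument list (instead of sys.argv) to show the resulting values.
args = build_parser().parse_args(["--data", "PEMS-D8", "--device", "cpu", "--lr", "0.001"])
print(args.data, args.device, args.lr, args.epochs)  # PEMS-D8 cpu 0.001 500
```

Leaving `--device` as a free-form string, rather than restricting it with `choices`, keeps indexed devices such as `cuda:1` valid.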
To evaluate a trained model:
python test.py --data ./data/PEMS-D3 --device cuda:0
The model expects data in .npz format with 'x' and 'y' arrays for input and output sequences.
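The sketch below illustrates the 'x'/'y' layout by slicing a raw series into sliding windows and round-tripping them through an .npz file. It assumes the raw data is a (time, nodes, features) array and uses 12 input and 12 output steps; these window lengths, the helper name `make_windows`, and the file name `train.npz` are illustrative assumptions, not necessarily what generate_training_data.py produces.

```python
import os
import tempfile

import numpy as np


def make_windows(series, n_in=12, n_out=12):
    """Slice a (T, N, F) series into x: (S, n_in, N, F) and y: (S, n_out, N, F)."""
    num_samples = series.shape[0] - n_in - n_out + 1
    if num_samples <= 0:
        raise ValueError("series is shorter than n_in + n_out")
    x = np.stack([series[i:i + n_in] for i in range(num_samples)])
    y = np.stack([series[i + n_in:i + n_in + n_out] for i in range(num_samples)])
    return x, y


# Toy series: 100 time steps, 3 sensors, 1 feature.
raw = np.arange(100 * 3 * 1, dtype=np.float32).reshape(100, 3, 1)
x, y = make_windows(raw)
print(x.shape, y.shape)  # (77, 12, 3, 1) (77, 12, 3, 1)

# Save and reload with 'x' and 'y' keys, the layout the model expects.
path = os.path.join(tempfile.mkdtemp(), "train.npz")
np.savez_compressed(path, x=x, y=y)
with np.load(path) as data:
    x_loaded, y_loaded = data["x"], data["y"]
```

Each target window `y[i]` starts exactly where its input window `x[i]` ends, so consecutive samples overlap by all but one time step.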
Use generate_training_data.py to convert raw traffic data:
python generate_training_data.py --output_dir data/PEMS-D3 --traffic_df_filename data/PEMS03.npz
Project structure:
STCDN/
├── main.py # Training script
├── model.py # Main model definition
├── encoder.py # Encoder with ODE-GAT
├── decoder.py # Decoder with ODE-GAT
├── gat.py # Graph Attention Network implementation
├── graph_learner.py # Adaptive graph learning module
├── holder.py # Training controller
├── utils.py # Utility functions
├── test.py # Evaluation script
├── generate_training_data.py # Data preprocessing
├── requirements.txt # Python dependencies
└── README.md # This file
If you use this code in your research, please cite the original paper (to be added).
MIT License (see LICENSE file for details)