
Xieyyyy/STCDN


Spatial-Temporal Continuous Dynamics Network (STCDN)

This is an implementation of Spatial-Temporal Continuous Dynamics Network (STCDN), a neural network architecture for traffic forecasting that combines Graph Attention Networks (GAT) with Neural Ordinary Differential Equations (Neural ODEs) to model complex spatial-temporal dependencies in traffic data.
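In general terms, this combination can be written as a graph-conditioned Neural ODE. The formulation below is a generic sketch in our own notation, not the paper's exact equations. H(t) stacks the hidden states of the N sensors (d features each), A is the predefined or learned graph, and f_θ stands for the GAT-based function that defines the dynamics.

```latex
\frac{d\mathbf{H}(t)}{dt} = f_\theta\big(\mathbf{H}(t), \mathbf{A}\big),
\qquad
\mathbf{H}(t_1) = \mathbf{H}(t_0) + \int_{t_0}^{t_1} f_\theta\big(\mathbf{H}(t), \mathbf{A}\big)\,dt,
\qquad
\mathbf{H}(t) \in \mathbb{R}^{N \times d}
```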

Overview

Traditional traffic forecasting models often struggle to capture complex spatial-temporal dependencies. STCDN addresses this by:

  1. Using Neural ODEs to model continuous temporal dynamics
  2. Employing Graph Attention Networks to capture spatial dependencies
  3. Learning adaptive graph structures from data
  4. Supporting multiple traffic datasets (PEMS series)
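To illustrate how a GAT layer weighs neighbouring sensors, here is a minimal single-head graph attention sketch in pure Python, following the original GAT formulation (score e_ij = LeakyReLU(a_src · Wh_i + a_dst · Wh_j), softmax over each node's neighbourhood). It is not the code in gat.py: the function name, argument layout and the omission of multi-head attention and the output nonlinearity are assumptions made for brevity.

```python
import math


def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x


def gat_layer(features, weight, attn_src, attn_dst, neighbors):
    """Single-head graph attention (illustrative sketch, not gat.py).

    features:  list of N feature vectors (length F_in)
    weight:    F_out x F_in projection matrix, given as a list of rows
    attn_src, attn_dst: attention vectors of length F_out
    neighbors: dict node -> list of neighbour ids (include self-loops)
    """
    # Project every node: z_i = W h_i
    z = [[sum(w * h for w, h in zip(row, feat)) for row in weight] for feat in features]
    out = []
    for i in range(len(features)):
        nbrs = neighbors[i]
        # Unnormalised scores e_ij = LeakyReLU(a_src . z_i + a_dst . z_j)
        scores = [
            leaky_relu(sum(a * v for a, v in zip(attn_src, z[i]))
                       + sum(a * v for a, v in zip(attn_dst, z[j])))
            for j in nbrs
        ]
        # Softmax over the neighbourhood (subtract the max for numerical stability)
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        alphas = [e / total for e in exps]
        # New node state: attention-weighted sum of neighbour projections
        out.append([sum(alpha * z[j][k] for alpha, j in zip(alphas, nbrs))
                    for k in range(len(weight))])
    return out
```

With all-zero attention vectors every neighbour receives the same weight, so the layer reduces to neighbourhood averaging; learned attention vectors let the model emphasise some sensors over others. DGL, which is listed under Requirements, ships a ready-made `GATConv` module for the same operation.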

Key Features

  • Continuous Time Modeling: Neural ODEs model temporal evolution as continuous processes
  • Graph Neural Networks: GAT layers capture complex spatial relationships between sensors
  • Adaptive Graph Learning: Learns connections between nodes from data rather than relying solely on predefined graphs
  • Multi-Dataset Support: Works with various PEMS traffic datasets
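One common way to learn a graph from data is to derive a row-normalised adjacency matrix from two sets of node embeddings, A = softmax(ReLU(E1 E2^T)), as popularised by Graph WaveNet. The sketch below shows that technique in pure Python. Whether graph_learner.py uses this exact form is an assumption; the function name and inputs are illustrative, and in a real model the embeddings would be trainable parameters.

```python
import math


def adaptive_adjacency(src_emb, dst_emb):
    """Row-normalised adjacency learned from node embeddings (illustrative).

    A[i][j] = softmax_j(ReLU(e1_i . e2_j)), in the style of Graph WaveNet.
    src_emb and dst_emb are lists of N embedding vectors of equal length.
    """
    adj = []
    for e1 in src_emb:
        # ReLU keeps only positive affinities between node i and each node j
        logits = [max(0.0, sum(a * b for a, b in zip(e1, e2))) for e2 in dst_emb]
        # Softmax turns each row into a probability distribution over neighbours
        m = max(logits)
        exps = [math.exp(v - m) for v in logits]
        total = sum(exps)
        adj.append([e / total for e in exps])
    return adj
```

Because every row sums to one, the learned matrix can be used directly as a weighted neighbourhood for message passing, alongside or instead of the predefined sensor graph.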

Architecture

The model follows an encoder-decoder architecture:

  1. Encoder: Processes historical traffic data using GAT layers within an ODE framework
  2. Hidden Transformation: Transforms the encoder output into the decoder input
  3. Decoder: Generates future traffic predictions using another ODE-GAT module
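Both the encoder and the decoder evolve a hidden state through an ODE. The sketch below shows that continuous-time step with a classic fixed-step RK4 solver in pure Python. As an assumption for illustration, the learned ODE-GAT function is replaced by a fixed graph-diffusion vector field (dh_i/dt = Σ_j A_ij h_j − h_i, one scalar per node); `rk4_integrate` and `graph_diffusion` are hypothetical helpers, not functions from this repository.

```python
def rk4_integrate(f, h0, t0, t1, steps):
    """Integrate dh/dt = f(t, h) from t0 to t1 with fixed-step RK4.

    Returns the list of states at every grid point, starting with h0.
    """
    dt = (t1 - t0) / steps
    h, t = list(h0), t0
    trajectory = [list(h)]
    for _ in range(steps):
        k1 = f(t, h)
        k2 = f(t + dt / 2, [x + dt / 2 * k for x, k in zip(h, k1)])
        k3 = f(t + dt / 2, [x + dt / 2 * k for x, k in zip(h, k2)])
        k4 = f(t + dt, [x + dt * k for x, k in zip(h, k3)])
        h = [x + dt / 6 * (a + 2 * b + 2 * c + d)
             for x, a, b, c, d in zip(h, k1, k2, k3, k4)]
        t += dt
        trajectory.append(list(h))
    return trajectory


def graph_diffusion(adj):
    """Stand-in vector field: dh_i/dt = sum_j A_ij h_j - h_i."""
    def f(t, h):
        return [sum(a * hj for a, hj in zip(row, h)) - hi for row, hi in zip(adj, h)]
    return f


# Two connected sensors with opposite initial states drift towards agreement.
traj = rk4_integrate(graph_diffusion([[0.0, 1.0], [1.0, 0.0]]), [1.0, -1.0], 0.0, 1.0, 20)
```

Reading the trajectory off at the future time points is one way a decoder can emit a prediction per horizon. In PyTorch, `torchdiffeq.odeint(func, y0, t)` performs this integration with adaptive solvers and supports backpropagation.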

Requirements

  • Python 3.7+
  • PyTorch 1.8+
  • DGL (Deep Graph Library)
  • torchdiffeq
  • NumPy
  • SciPy

Install dependencies with:

pip install -r requirements.txt
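To confirm that the installation succeeded, a small check like the following can help. It is a convenience sketch, not part of the repository, and it assumes the import names `torch`, `dgl`, `torchdiffeq`, `numpy` and `scipy` for the packages listed under Requirements.

```python
import importlib.util

# Assumed import names for the packages listed under Requirements.
REQUIRED_MODULES = ["torch", "dgl", "torchdiffeq", "numpy", "scipy"]


def missing_modules(names=REQUIRED_MODULES):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    print("All dependencies found." if not missing else f"Missing: {', '.join(missing)}")
```

`find_spec` only locates the modules without importing them, so the check stays fast even for large packages such as PyTorch.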

Usage

Training

To train the model on a PEMS dataset:

python main.py --data PEMS-D3 --device cuda:0 --lr 0.003 --epochs 500

Key arguments:

  • --data: Dataset to use (PEMS-D3, PEMS-D4, PEMS-D7, PEMS-D8, etc.)
  • --device: Device to use (cuda:0, cpu, mps)
  • --lr: Learning rate
  • --epochs: Number of training epochs
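For reference, the documented flags map onto a command-line parser roughly as sketched below. This is a hypothetical `argparse` setup, not the actual parser in main.py. The defaults are copied from the example training command, and the real script may use other defaults and accept further options.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented main.py flags."""
    parser = argparse.ArgumentParser(description="Train STCDN (illustrative sketch)")
    parser.add_argument("--data", default="PEMS-D3",
                        help="dataset name, e.g. PEMS-D3, PEMS-D4, PEMS-D7, PEMS-D8")
    parser.add_argument("--device", default="cuda:0",
                        help="torch device string: cuda:0, cpu or mps")
    parser.add_argument("--lr", type=float, default=0.003, help="learning rate")
    parser.add_argument("--epochs", type=int, default=500, help="number of training epochs")
    return parser
```

Calling `build_parser().parse_args(["--data", "PEMS-D8", "--device", "cpu"])` overrides the dataset and device while keeping the example learning rate and epoch count.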

Testing

To evaluate a trained model:

python test.py --data ./data/PEMS-D3 --device cuda:0
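Traffic forecasting on PEMS data is usually scored with MAE, RMSE and MAPE, with zero readings treated as missing. The sketch below computes these masked metrics over flattened predictions in pure Python. That test.py reports exactly these metrics and masks zeros this way is an assumption; `masked_metrics` is a hypothetical helper.

```python
import math


def masked_metrics(preds, labels, null_value=0.0):
    """MAE, RMSE and MAPE (in %) over entries whose label != null_value.

    Zero readings in PEMS data usually denote missing values, so they are
    excluded (an assumption; test.py may handle them differently).
    """
    pairs = [(p, y) for p, y in zip(preds, labels) if y != null_value]
    if not pairs:
        raise ValueError("no valid labels to evaluate")
    n = len(pairs)
    mae = sum(abs(p - y) for p, y in pairs) / n
    rmse = math.sqrt(sum((p - y) ** 2 for p, y in pairs) / n)
    mape = sum(abs(p - y) / abs(y) for p, y in pairs) / n * 100
    return mae, rmse, mape
```

Masking matters mostly for MAPE: a zero label would otherwise cause a division by zero.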

Data Preparation

The model expects data in .npz format with 'x' and 'y' arrays for input and output sequences.

Use generate_training_data.py to convert raw traffic data:

python generate_training_data.py --output_dir data/PEMS-D3 --traffic_df_filename data/PEMS03.npz
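The 'x' and 'y' arrays are typically built by sliding a window over the raw (time steps × sensors) readings: each input window of past steps is paired with the window of steps that follows it. The pure-Python sketch below shows that technique. The window lengths of 12 (one hour of 5-minute PEMS readings) are an assumed, commonly used setting, and `make_windows` is a hypothetical helper, not a function in generate_training_data.py.

```python
def make_windows(series, in_len=12, out_len=12):
    """Split a (T, N) series into (x, y) samples with a sliding window.

    series: list of T time steps, each a list of N sensor readings.
    x[i] holds in_len consecutive steps; y[i] holds the out_len steps after them.
    """
    x, y = [], []
    for start in range(len(series) - in_len - out_len + 1):
        x.append(series[start:start + in_len])
        y.append(series[start + in_len:start + in_len + out_len])
    return x, y
```

A series of T steps yields T − in_len − out_len + 1 samples. The resulting lists can be converted to arrays and stored with `numpy.savez_compressed(path, x=..., y=...)` to produce the .npz layout described above.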

Project Structure

STCDN/
├── main.py              # Training script
├── model.py             # Main model definition
├── encoder.py           # Encoder with ODE-GAT
├── decoder.py           # Decoder with ODE-GAT
├── gat.py               # Graph Attention Network implementation
├── graph_learner.py     # Adaptive graph learning module
├── holder.py            # Training controller
├── utils.py             # Utility functions
├── test.py              # Evaluation script
├── generate_training_data.py  # Data preprocessing
├── requirements.txt     # Python dependencies
└── README.md            # This file

Citation

If you use this code in your research, please cite the original paper (to be added).

License

MIT License (see LICENSE file for details)

About

The implementation of Spatial-Temporal Continuous Dynamics Network
