Xieyyyy/SAST-GNN

SAST-GNN: Self-Attention based Spatio-Temporal Graph Neural Network for Traffic Flow Prediction (DASFAA 2020)

🚦 Project Overview

SAST-GNN is a high-precision traffic flow prediction system based on spatial-temporal graph neural networks. This project combines the advantages of Transformer and Graph Attention Network (GAT) to accurately predict future traffic flow in urban transportation networks.

Core Features

  • Spatial Modeling: Uses GAT to capture spatial correlations in road networks
  • Temporal Modeling: Uses Transformer to capture temporal dependencies in traffic patterns
  • End-to-End: Complete solution from data preprocessing to model deployment
  • High Performance: Supports GPU acceleration, scalable to large-scale traffic networks

πŸ“ Project Structure

SAST-GNN/
├── src/                          # Source code directory
│   ├── __init__.py
│   ├── main.py                  # Main training script
│   ├── data/                    # Data processing module
│   │   ├── __init__.py
│   │   └── dataset.py          # Dataset loading and preprocessing
│   ├── models/                  # Model architectures
│   │   ├── __init__.py
│   │   ├── sast_gnn.py        # SAST-GNN main model
│   │   ├── gat_dgl.py         # Graph attention network implementation
│   │   ├── spatial_temporal.py # Spatial-temporal module
│   │   └── singleTransformerBlock.py # Transformer block
│   ├── training/               # Training module
│   │   ├── __init__.py
│   │   └── trainer.py        # Trainer
│   └── utils/                  # Utility functions
│       ├── __init__.py
│       ├── config.py         # Configuration management
│       ├── data_utils.py     # Data utilities
│       └── metrics.py        # Evaluation metrics
├── config/                     # Configuration files
│   └── config.yaml           # Main configuration file
├── data/                      # Data directory
│   ├── V_228.csv            # Traffic flow data
│   ├── W_228.csv            # Adjacency matrix weights
│   └── PeMSD7_Full.zip      # Complete dataset
├── models/                    # Model save directory
├── logs/                      # Log directory
├── record/                    # Training records
├── requirements.txt          # Dependencies list
└── README.md                # Project documentation

🚀 Quick Start

1. Environment Setup

# Create virtual environment
python -m venv venv
source venv/bin/activate  # Linux/Mac
# venv\Scripts\activate   # Windows

# Install dependencies
pip install -r requirements.txt

2. Data Preparation

The project includes sample data:

  • V_228.csv: Speed data from 228 traffic detectors
  • W_228.csv: Adjacency matrix of traffic network
  • PeMSD7_Full.zip: Complete PeMS dataset

For custom data, ensure:

  • Flow data dimensions: [time_steps, nodes, features]
  • Adjacency matrix dimensions: [nodes, nodes]
  • Data format matches sample files
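
The dimension requirements above can be verified before training. The sketch below is a hypothetical pre-flight helper (`check_shapes` is not part of the repo) that validates the stated contract from plain shape tuples:

```python
def check_shapes(flow_shape, adj_shape):
    """flow_shape: [time_steps, nodes, features]; adj_shape: [nodes, nodes]."""
    if len(flow_shape) != 3:
        raise ValueError("flow data must be 3-D: [time_steps, nodes, features]")
    if len(adj_shape) != 2 or adj_shape[0] != adj_shape[1]:
        raise ValueError("adjacency matrix must be square: [nodes, nodes]")
    if flow_shape[1] != adj_shape[0]:
        raise ValueError("node count differs between flow data and adjacency matrix")
```

For the bundled 228-node sample this would be called as `check_shapes((time_steps, 228, 1), (228, 228))`.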

3. Model Training

Basic Training

# Train with default configuration
python src/main.py

Advanced Training Options

# Specify configuration file
python src/main.py --config config/custom_config.yaml

# Specify device
python src/main.py --device cuda

# Specify training/validation/test days
python src/main.py --n_train 31 --n_val 9 --n_test 4

4. Model Evaluation

After training completes, the following results are reported automatically:

  • MAPE: Mean Absolute Percentage Error
  • MAE: Mean Absolute Error
  • RMSE: Root Mean Square Error
  • Training Time: Total training duration
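
These are the standard error metrics; a minimal stdlib sketch of how they are typically computed (the repo's own implementation lives in `src/utils/metrics.py` and may differ in details such as zero-handling):

```python
import math

def evaluate(y_true, y_pred):
    """Return MAE, RMSE, and MAPE (%) over flat sequences of values."""
    errs = [p - t for p, t in zip(y_pred, y_true)]
    mae = sum(abs(e) for e in errs) / len(errs)
    rmse = math.sqrt(sum(e * e for e in errs) / len(errs))
    # Skip zero ground-truth entries so MAPE stays finite.
    pct = [abs(e / t) for e, t in zip(errs, y_true) if t != 0]
    mape = 100 * sum(pct) / len(pct)
    return {"MAE": mae, "RMSE": rmse, "MAPE": mape}
```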

βš™οΈ Configuration Guide

Configuration File (config/config.yaml)

# Dataset Configuration
dataset:
  name: "PEMS"
  node_num: 228
  data_path: "../data/V_228.csv"
  adj_path: "../data/W_228.csv"

# Training Configuration
training:
  epochs: 50
  batch_size:
    train: 32
    val: 32
  learning_rate: 0.002
  device: "mps"

# Data Configuration
data:
  n_train: 31      # Training days
  n_val: 9         # Validation days
  n_test: 4        # Test days
  n_history: 12    # Historical time steps
  n_prediction: 12 # Prediction time steps
  input_dim: 1     # Input feature dimension
  day_slot: 288    # Daily time slots

# Model Architecture
model:
  temporal_module: "Transformer"
  spatial_module: "GAT"
  dropout: 0.5
  
  upper_temporal:
    hidden_dim: 512
    num_layers: 3
    features: 512
    heads: 8
    
  spatial:
    in_features: 512
    out_features: 64
    heads: 8
    
  lower_temporal:
    features: 64
    hidden_dim: 64
    num_layers: 3
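
Once parsed (e.g. with PyYAML's `yaml.safe_load`), the file above becomes a nested mapping. A small stdlib sketch of dotted-key access over such a mapping (`get` is illustrative; the repo's `src/utils/config.py` presumably wraps configuration access differently):

```python
from functools import reduce

# Excerpt of the parsed config.yaml as a nested dict (values from above).
CONFIG = {
    "training": {"epochs": 50, "learning_rate": 0.002, "device": "mps"},
    "data": {"n_history": 12, "n_prediction": 12, "day_slot": 288},
}

def get(cfg, dotted, default=None):
    """Fetch 'training.learning_rate'-style keys from a nested dict."""
    try:
        return reduce(lambda d, k: d[k], dotted.split("."), cfg)
    except (KeyError, TypeError):
        return default
```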

📊 Data Format

Input Data

  • Flow Data: CSV format, each row represents a time step, each column represents a detector
  • Adjacency Matrix: CSV format, represents connection relationships and weights in traffic network
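
A minimal stdlib sketch of reading either CSV into a row-major list of floats (rows are time steps for `V_228.csv` and nodes for `W_228.csv`; `load_csv_matrix` is a hypothetical helper, and the files are assumed to be headerless):

```python
import csv

def load_csv_matrix(path):
    """Read a headerless numeric CSV into a list of float rows."""
    with open(path, newline="") as f:
        return [[float(x) for x in row] for row in csv.reader(f)]
```

For `V_228.csv` this yields one row per time step with 228 columns; a trailing feature axis of size 1 is then added to match the expected [time_steps, nodes, features] layout.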

Output Format

Model outputs include:

  • Traffic flow predictions for next 12 time steps
  • Individual predictions for each detector
  • Confidence intervals (optional)

🔧 Custom Usage

1. Custom Dataset

from src.data.dataset import load_dataset

# Load custom data
dataset = load_dataset(
    data_path="path/to/your/data.csv",
    data_split=(train_days, val_days, test_days),
    node_num=num_sensors,
    seq_len=history_steps + prediction_steps,
    day_slot=time_slots_per_day,
    normalize=True
)

2. Model Fine-tuning

from src.models.sast_gnn import SASTGNN
from src.utils.config import Config

config = Config()
config.model.temporal_module = "LSTM"  # Change to LSTM
config.model.spatial_module = "GAT"
config.training.epochs = 100

model = SASTGNN(config)

3. Predict New Data

import torch
from src.models.sast_gnn import SASTGNN
from src.utils.config import Config

# Load trained model
config = Config()
model = SASTGNN(config)
checkpoint = torch.load('models/sast_gnn_final.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # disable dropout for inference

# Prepare input data (target_data and adjacency_matrix as in training)
input_data = ...  # [batch, history, nodes, features]
with torch.no_grad():
    prediction = model(input_data, target_data, adjacency_matrix)

πŸŽ›οΈ Advanced Features

1. Multi-GPU Training

# Auto-select available GPU
python src/main.py --device auto

# Specify multiple GPUs
python src/main.py --device cuda:0,1,2,3

2. Model Checkpoints

# Save checkpoints
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, 'checkpoints/checkpoint_epoch_{}.pth'.format(epoch))
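
The matching resume step follows the usual PyTorch pattern; a sketch assuming `model` and `optimizer` have already been re-created with the same configuration (`load_checkpoint` is illustrative, not a repo function):

```python
import torch

def load_checkpoint(path, model, optimizer):
    """Restore weights and optimizer state; return the epoch to resume from."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"] + 1
```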

3. Hyperparameter Tuning

# Use optuna for hyperparameter optimization
python scripts/hyperparameter_tuning.py
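
The tuning script itself is not shown here; as a rough stand-in, a stdlib random-search sketch over a space mirroring fields from `config.yaml` (the space and helper names are illustrative; in practice Optuna's `create_study`/`optimize` would replace the loop):

```python
import random

# Illustrative search space drawn from fields in config.yaml.
SPACE = {
    "learning_rate": [0.001, 0.002, 0.005],
    "dropout": [0.3, 0.5, 0.7],
    "heads": [4, 8],
}

def random_search(objective, n_trials=20, seed=0):
    """Minimize objective(params) over random samples from SPACE."""
    rng = random.Random(seed)
    best = None
    for _ in range(n_trials):
        params = {k: rng.choice(v) for k, v in SPACE.items()}
        score = objective(params)  # e.g. validation MAPE, lower is better
        if best is None or score < best[0]:
            best = (score, params)
    return best
```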

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.
