A PyTorch Lightning-based framework for geospatial deep learning with multi-sensor Earth observation data.
Geo Deep Learning (GDL) is a modular framework designed to support a wide range of geospatial deep learning tasks such as semantic segmentation, object detection, and regression. Built on PyTorch Lightning, it provides efficient training pipelines for multi-sensor data.
- Multi-sensor Support: Handle multiple Earth observation sensors simultaneously.
- Modular Architecture: Encoder-neck-decoder pattern with interchangeable components.
- WebDataset Integration: Efficient large-scale data loading and processing.
- Multiple Model Types: UNet++, SegFormer, and DOFA (Dynamic One-For-All architecture).
- Distributed Training: Multi-GPU training via PyTorch Lightning strategies such as DDP.
- MLflow Logging: Comprehensive experiment tracking and model versioning.
- Flexible Data Pipeline: Support for CSV and WebDataset formats.
```
├── models/
│   ├── encoders/         # DOFA, MixTransformer backbones
│   ├── necks/            # Multi-level feature processing
│   ├── decoders/         # UperNet decoder implementation
│   └── heads/            # Segmentation heads (FCN, etc.)
├── datamodules/          # Lightning DataModules
├── datasets/             # WebDataset and CSV dataset implementations
├── tasks_with_models/    # Lightning modules for training
├── tools/                # Utilities, callbacks, visualization
└── samplers/             # Custom data sampling strategies
```
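The modules above compose following the encoder-neck-decoder-head pattern. As a rough, self-contained sketch of that pattern (the class and attribute names below are illustrative placeholders, not GDL's actual APIs):

```python
# Illustrative sketch of the encoder-neck-decoder-head composition pattern.
# Placeholder class; the real building blocks live under models/.
import torch
from torch import nn


class ComposedSegmentationModel(nn.Module):
    def __init__(self, encoder: nn.Module, neck: nn.Module, decoder: nn.Module, head: nn.Module):
        super().__init__()
        self.encoder = encoder  # e.g. a DOFA or MixTransformer backbone
        self.neck = neck        # multi-level feature processing
        self.decoder = decoder  # e.g. a UperNet-style decoder
        self.head = head        # e.g. an FCN segmentation head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.encoder(x)      # multi-scale feature maps
        features = self.neck(features)  # harmonize/refine feature levels
        decoded = self.decoder(features)
        return self.head(decoded)       # per-pixel class logits
```

Because each stage depends only on the tensor interface of the previous one, encoders, necks, decoders, and heads can be swapped independently through configuration.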
- Install the uv package manager for your OS.
- Clone the repository:
```bash
git clone https://github.com/NRCan/geo-deep-learning.git
cd geo-deep-learning
```
- Install dependencies:
For GPU training with CUDA 12.8:
```bash
uv sync --extra cu128
```
For CPU-only training:
```bash
uv sync --extra cpu
```
This creates a virtual environment in .venv/ and installs all dependencies.
- Activate the environment:
```bash
# Linux/macOS
source .venv/bin/activate

# Windows
.venv\Scripts\activate
```
Or use uv run to execute commands without manual activation:
```bash
uv run python geo_deep_learning/train.py fit --config configs/dofa_config_RGB.yaml
```
Note: If you prefer to use conda or another environment manager, you can generate a requirements.txt file from the dependencies listed in pyproject.toml for manual installation.
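After installation, an optional smoke test confirms that PyTorch is importable and reports whether a CUDA device is visible (run it inside the activated environment or via uv run):

```python
# Optional smoke test: check the installed PyTorch version and GPU visibility.
import torch

print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
```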
Models are configured via YAML files in the configs/ directory:
```yaml
model:
  class_path: tasks_with_models.segmentation_dofa.SegmentationDOFA
  init_args:
    encoder: "dofa_base"
    pretrained: true
    image_size: [512, 512]
    num_classes: 5
    # ... other parameters

data:
  class_path: datamodules.wds_datamodule.MultiSensorDataModule
  init_args:
    sensor_configs_path: "path/to/sensor_configs.yaml"
    batch_size: 16
    patch_size: [512, 512]

trainer:
  max_epochs: 100
  precision: 16-mixed
  accelerator: gpu
  devices: 1
```
- UNet++: classic U-Net architecture extended with dense skip connections.
- Multiple encoder backbones (ResNet, EfficientNet, etc.).
- Available through segmentation-models-pytorch.
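Because UNet++ is provided through segmentation-models-pytorch, the equivalent model can also be built directly with that library; a minimal sketch with example values (this bypasses GDL's YAML configuration):

```python
# Minimal UNet++ from segmentation-models-pytorch (example encoder and class count).
import segmentation_models_pytorch as smp

model = smp.UnetPlusPlus(
    encoder_name="resnet34",     # or an EfficientNet backbone, e.g. "efficientnet-b0"
    encoder_weights="imagenet",  # pretrained encoder weights
    in_channels=3,               # number of input bands
    classes=5,                   # number of output classes
)
```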
- SegFormer: transformer-based architecture for semantic segmentation.
- Hierarchical feature representation (MixTransformer encoder).
- Multiple model sizes (B0-B5).
- DOFA Base: 768-dim embeddings, suitable for most tasks.
- DOFA Large: 1024-dim embeddings, higher capacity.
- Multi-scale feature extraction with UperNet decoder.
- Support for wavelength-specific processing.
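In DOFA-style models, wavelength-specific processing means the input layer's weights are generated from each band's centre wavelength, so a single model can ingest different band combinations. The following is a self-contained toy illustration of that idea, not GDL's or DOFA's actual code; dimensions and wavelengths are example values:

```python
# Toy wavelength-conditioned ("dynamic one-for-all") patch embedding.
# Not GDL/DOFA source code; shapes and wavelengths are example values.
import torch
import torch.nn.functional as F
from torch import nn


class WavelengthConditionedEmbed(nn.Module):
    def __init__(self, embed_dim: int = 64, patch_size: int = 16):  # DOFA uses 768 (Base) / 1024 (Large)
        super().__init__()
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        # Small hypernetwork: maps a band's centre wavelength to that band's embedding weights.
        self.weight_gen = nn.Sequential(
            nn.Linear(1, 128),
            nn.GELU(),
            nn.Linear(128, embed_dim * patch_size * patch_size),
        )

    def forward(self, x: torch.Tensor, wavelengths: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W); wavelengths: (C,) band centres in micrometres.
        bands = x.shape[1]
        w = self.weight_gen(wavelengths.unsqueeze(-1))                       # (C, D*p*p)
        w = w.view(bands, self.embed_dim, self.patch_size, self.patch_size)  # (C, D, p, p)
        w = w.permute(1, 0, 2, 3)                                            # (D, C, p, p)
        return F.conv2d(x, w, stride=self.patch_size)                        # (B, D, H/p, W/p)


embed = WavelengthConditionedEmbed()
rgb = embed(torch.randn(1, 3, 512, 512), torch.tensor([0.665, 0.560, 0.490]))          # 3-band RGB
rgbn = embed(torch.randn(1, 4, 512, 512), torch.tensor([0.665, 0.560, 0.490, 0.842]))  # RGB + NIR
```

The same module handles both calls even though the band counts differ, which is what enables mixing sensors with different band sets.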
- Sensor Mixing: Combine data from multiple sensors during training.
- WebDataset Format: Efficient sharded data storage and loading.
- WebDataset: Sharded tar files with metadata.
- CSV: Traditional CSV with file paths and labels.
- Multi-sensor: YAML configuration for sensor-specific settings.
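For reference, a sharded WebDataset is a sequence of tar files streamed sample by sample. A minimal reading sketch with the webdataset package, where the shard pattern and per-sample keys are assumptions rather than GDL's actual naming:

```python
# Minimal WebDataset reading sketch; shard names and sample keys are assumptions.
import webdataset as wds

shards = "data/train-{000000..000009}.tar"  # brace-expanded list of shard files
dataset = (
    wds.WebDataset(shards)
    .decode()  # decode entries by file extension (.npy, .json, ...)
    .to_tuple("image.npy", "mask.npy", "metadata.json")
)

for image, mask, metadata in dataset:
    print(image.shape, mask.shape, sorted(metadata))
    break
```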
- Large-scale Training: Distributed training strategies enabled through PyTorch Lightning.
- Mixed Precision Training: 16-bit mixed precision for faster training.
- Gradient Clipping: Configurable gradient clipping for stability.
- Early Stopping: Automatic training termination based on validation metrics.
- Model Checkpointing: Saves best models based on validation performance.
- MLflow Integration: Experiment tracking, metrics logging, and model registry.
- Visualization Callbacks: Built-in prediction visualization during training.
- Learning Rate Scheduling: Cosine annealing, step decay, and more.
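In GDL these options are driven from the YAML config, but they map onto standard PyTorch Lightning components; a rough equivalent in plain Lightning, with example metric names and values:

```python
# Generic Lightning setup mirroring the features above (example values, not GDL defaults).
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger

trainer = Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=2,
    strategy="ddp",          # distributed data parallel across GPUs
    precision="16-mixed",    # 16-bit mixed precision
    gradient_clip_val=1.0,   # gradient clipping for stability
    logger=MLFlowLogger(experiment_name="geo-deep-learning"),
    callbacks=[
        EarlyStopping(monitor="val_loss", patience=10, mode="min"),
        ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min"),
    ],
)
# Learning-rate schedules (cosine annealing, step decay, ...) are defined in the
# LightningModule's configure_optimizers(), not on the Trainer.
```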
- Follows PEP 8 with 88-character line limit
- Uses Ruff for linting and formatting
- Type hints for all function signatures
- Comprehensive docstrings
```bash
# Lint code
ruff check .

# Format code
ruff format .
```
- Fork the repository
- Create a feature branch
- Make your changes
- Add tests if applicable
- Submit a pull request
For issues and questions:
- Open an issue on GitHub