Skip to content

Commit 713040d

Browse files
committed
Merge remote-tracking branch 'origin/main'
1 parent cf954be commit 713040d

File tree

14 files changed

+191
-75
lines changed

14 files changed

+191
-75
lines changed

.github/workflows/build-on-tag.yml

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
2+
name: Build Docker Image on Tag (CUDA 12)
3+
14
on:
2-
create
5+
push:
6+
tags:
7+
- '*'
38

49
env:
510
REGISTRY: docker.io
@@ -28,16 +33,16 @@ jobs:
2833
images: ${{ env.REGISTRY }}/${{ secrets.DOCKER_REPO }}
2934

3035
- name: Build the docker image
31-
run: docker build . --file Dockerfile --tag gdl-cuda11:${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }} --build-arg GIT_TAG=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }}
36+
run: docker build . --file Dockerfile --tag gdl-cuda12:${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }} --build-arg GIT_TAG=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }}
3237

3338
- name: Tag the docker image
34-
run: docker tag gdl-cuda11:${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }} ${{ secrets.DOCKER_REPO }}/gdl-cuda11:${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }}
39+
run: docker tag gdl-cuda12:${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }} ${{ secrets.DOCKER_REPO }}/gdl-cuda12:${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }}
3540

3641
- name: Push the docker image
37-
run: docker push ${{ secrets.DOCKER_REPO }}/gdl-cuda11:${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }}
42+
run: docker push ${{ secrets.DOCKER_REPO }}/gdl-cuda12:${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }}
3843

3944
- name: Tag the docker image to latest
40-
run: docker tag gdl-cuda11:${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }} ${{ secrets.DOCKER_REPO }}/gdl-cuda11:latest
45+
run: docker tag gdl-cuda12:${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }} ${{ secrets.DOCKER_REPO }}/gdl-cuda12:latest
4146

4247
- name: Push the docker image (latest tag)
43-
run: docker push ${{ secrets.DOCKER_REPO }}/gdl-cuda11:latest
48+
run: docker push ${{ secrets.DOCKER_REPO }}/gdl-cuda12:latest

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
*.idea**
33
*.vscode**
44

5+
# Distribution / packaging
6+
*.egg-info/
7+
58
# Specific folders name
69
waterloo_subset_512/
710
mlruns/

Dockerfile

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# syntax=docker/dockerfile:1
2+
FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
3+
4+
RUN apt-get update && apt-get install -y curl bzip2 && \
5+
curl -Ls https://micro.mamba.pm/api/micromamba/linux-64/latest | \
6+
tar -xvj -C /usr/local/bin --strip-components=1 bin/micromamba && \
7+
rm -rf /var/lib/apt/lists/*
8+
9+
ENV MAMBA_DOCKERFILE_ACTIVATE=1 \
10+
CONDA_ENV_NAME=geo-dl \
11+
MAMBA_ROOT_PREFIX=/opt/conda \
12+
PATH="/opt/conda/envs/geo-dl/bin:$PATH"
13+
14+
WORKDIR /tmp
15+
COPY requirements.txt pyproject.toml ./
16+
17+
# Create the conda env, install Python deps, then strip bytecode caches.
# The cache cleanup is best-effort: each `|| true` is grouped with its own
# `find` so it cannot mask a failed `micromamba create` / `pip install`
# (in sh, `a && b || true && c` parses as `((a && b) || true) && c`).
RUN micromamba create -y -n $CONDA_ENV_NAME -c conda-forge python=3.10 pip && \
18+
micromamba run -n $CONDA_ENV_NAME pip install --no-cache-dir -r requirements.txt && \
19+
(find $MAMBA_ROOT_PREFIX/envs/$CONDA_ENV_NAME -name "*.pyc" -delete 2>/dev/null || true) && \
20+
(find $MAMBA_ROOT_PREFIX/envs/$CONDA_ENV_NAME -name "__pycache__" -type d -exec rm -rf {} + 2>/dev/null || true) && \
21+
micromamba clean -a -y
22+
23+
RUN useradd -m -u 1000 gdl_user && mkdir -p /app && chown -R gdl_user /app
24+
USER gdl_user
25+
26+
WORKDIR /app
27+
COPY --chown=gdl_user:gdl_user . /app
28+
29+
ENTRYPOINT ["python"]
30+
CMD ["-m", "geo_deep_learning.train"]

README.md

Lines changed: 70 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,18 @@ A PyTorch Lightning-based framework for geospatial deep learning with multi-sens
44

55
## Overview
66

7-
Geo Deep Learning (GDL) is a modular framework designed for semantic segmentation of geospatial imagery using state-of-the-art deep learning models. Built on PyTorch Lightning, it provides efficient training pipelines for multi-sensor data with WebDataset support.
7+
Geo Deep Learning (GDL) is a modular framework designed to support a wide range of geospatial deep learning tasks such as semantic segmentation, object detection, and regression.
8+
Built on PyTorch Lightning, it provides efficient training pipelines for multi-sensor data.
89

910
## Features
1011

11-
- **Multi-sensor Support**: Handle multiple Earth observation sensors simultaneously
12-
- **Modular Architecture**: Encoder-neck-decoder pattern with interchangeable components
13-
- **WebDataset Integration**: Efficient large-scale data loading and processing
14-
- **Multiple Model Types**: UNet++, SegFormer, DOFA (Dynamic-one-for-all Architecture)
15-
- **Distributed Training**: Multi-GPU training with DDP strategy
16-
- **MLflow Logging**: Comprehensive experiment tracking and model versioning
17-
- **Flexible Data Pipeline**: Support for CSV and WebDataset formats
12+
- **Multi-sensor Support**: Handle multiple Earth observation sensors simultaneously.
13+
- **Modular Architecture**: Encoder-neck-decoder pattern with interchangeable components.
14+
- **WebDataset Integration**: Efficient large-scale data loading and processing.
15+
- **Multiple Model Types**: UNet++, SegFormer, DOFA (Dynamic One-For-All foundation model).
16+
- **Distributed Training**: Multi-GPU training with supported strategies.
17+
- **MLflow Logging**: Comprehensive experiment tracking and model versioning.
18+
- **Flexible Data Pipeline**: Support for CSV and WebDataset formats.
1819

1920
## Architecture
2021

@@ -31,23 +32,47 @@ Geo Deep Learning (GDL) is a modular framework designed for semantic segmentatio
3132
└── samplers/ # Custom data sampling strategies
3233
```
3334

35+
## Requirements
36+
- Install [uv](https://docs.astral.sh/uv/) package manager for your OS.
37+
3438
## Quick Start
3539

40+
1. **Clone the repository:**
3641
```bash
37-
git clone <repository-url>
42+
git clone https://github.com/NRCan/geo-deep-learning.git
3843
cd geo-deep-learning
3944
```
45+
2. **Install dependencies:**
4046

41-
### Training
47+
For **GPU training** with CUDA 12.8:
48+
```bash
49+
uv sync --extra cu128
50+
```
4251

52+
For **CPU-only** training:
4353
```bash
44-
# Single GPU training
45-
python geo_deep_learning/train.py fit --config configs/dofa_config_RGB.yaml
54+
uv sync --extra cpu
4655
```
56+
This creates a virtual environment in `.venv/` and installs all dependencies.
57+
58+
3. **Activate the environment:**
59+
```bash
60+
# Linux/macOS
61+
source .venv/bin/activate
62+
63+
# Windows
64+
.venv\Scripts\activate
65+
```
66+
67+
Or use `uv run` to execute commands without manual activation:
68+
```bash
69+
uv run python geo_deep_learning/train.py fit --config configs/dofa_config_RGB.yaml
70+
```
71+
**Note:** *If you prefer to use conda or another environment manager, you can generate a `requirements.txt` file from the dependencies listed in `pyproject.toml` for manual installation.*
4772

4873
### Configuration
4974

50-
Models are configured via YAML files in `configs/`:
75+
Models are configured via YAML files in the `configs/` directory:
5176

5277
```yaml
5378
model:
@@ -65,54 +90,53 @@ data:
6590
sensor_configs_path: "path/to/sensor_configs.yaml"
6691
batch_size: 16
6792
patch_size: [512, 512]
93+
94+
trainer:
95+
max_epochs: 100
96+
precision: 16-mixed
97+
accelerator: gpu
98+
devices: 1
6899
```
69100
70101
## Supported Models
71102
72-
### DOFA (Domain-Oriented Foundation Architecture)
73-
- **DOFA Base**: 768-dim embeddings, suitable for most tasks
74-
- **DOFA Large**: 1024-dim embeddings, higher capacity
75-
- Multi-scale feature extraction with UperNet decoder
76-
- Support for wavelength-specific processing
77-
78103
### UNet++
79-
- Classic U-Net architecture with dense skip connections
80-
- Multiple encoder backbones (ResNet, EfficientNet, etc.)
81-
- Optimized for medical and satellite imagery
104+
- Classic U-Net architecture with dense skip connections.
105+
- Multiple encoder backbones (ResNet, EfficientNet, etc.).
106+
- Available through segmentation-models-pytorch.
82107
83108
### SegFormer
84-
- Transformer-based architecture for semantic segmentation
85-
- Hierarchical feature representation
86-
- Efficient attention mechanisms
109+
- Transformer-based architecture for semantic segmentation.
110+
- Hierarchical feature representation (MixTransformer encoder).
111+
- Multiple model sizes (B0-B5).
112+
113+
### DOFA (Dynamic One-For-All foundation model)
114+
- **DOFA Base**: 768-dim embeddings, suitable for most tasks.
115+
- **DOFA Large**: 1024-dim embeddings, higher capacity.
116+
- Multi-scale feature extraction with UperNet decoder.
117+
- Support for wavelength-specific processing.
118+
87119
88120
## Data Pipeline
89121
90122
### Multi-Sensor DataModule
91-
- **Sensor Mixing**: Combine data from multiple sensors during training
92-
- **WebDataset Format**: Efficient sharded data storage and loading
93-
- **Patch-based Processing**: Configurable patch sizes (default: 512x512)
94-
- **Data Augmentation**: Built-in augmentation pipeline
123+
- **Sensor Mixing**: Combine data from multiple sensors during training.
124+
- **WebDataset Format**: Efficient sharded data storage and loading.
95125
96126
### Supported Data Formats
97-
- **WebDataset**: Sharded tar files with metadata
98-
- **CSV**: Traditional CSV with file paths and labels
99-
- **Multi-sensor**: YAML configuration for sensor-specific settings
127+
- **WebDataset**: Sharded tar files with metadata.
128+
- **CSV**: Traditional CSV with file paths and labels.
129+
- **Multi-sensor**: YAML configuration for sensor-specific settings.
100130
101131
## Training Features
102-
103-
- **Mixed Precision**: 16-bit mixed precision training
104-
- **Gradient Clipping**: Configurable gradient clipping
105-
- **Early Stopping**: Automatic training termination
106-
- **Model Checkpointing**: Best model saving based on validation metrics
107-
- **Visualization**: Built-in prediction visualization callbacks
108-
109-
## Distributed Training
110-
111-
The framework supports multi-GPU training with:
112-
- DDP (Distributed Data Parallel) strategy
113-
- Automatic mixed precision
114-
- Synchronized batch normalization
115-
- Efficient NCCL communication
132+
- **Large-scale training**: Distributed training strategies enabled with PyTorch Lightning.
133+
- **Mixed Precision Training**: 16-bit mixed precision for faster training.
134+
- **Gradient Clipping**: Configurable gradient clipping for stability.
135+
- **Early Stopping**: Automatic training termination based on validation metrics.
136+
- **Model Checkpointing**: Saves best models based on validation performance.
137+
- **MLflow Integration**: Experiment tracking, metrics logging, and model registry.
138+
- **Visualization Callbacks**: Built-in prediction visualization during training.
139+
- **Learning Rate Scheduling**: Cosine annealing, step decay, and more.
116140
117141
## Development
118142
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Logging configuration."""
File renamed without changes.

geo_deep_learning/models/encoders/mix_transformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
import torch
99
import torch.nn.functional as fn
10-
from models.segmentation.base import EncoderMixin
1110
from timm.layers import DropPath, to_2tuple, trunc_normal_
1211
from torch import Tensor, nn
1312
from torch.utils import model_zoo
1413

14+
from geo_deep_learning.models.segmentation.base import EncoderMixin
15+
1516

1617
class Mlp(nn.Module):
1718
"""MLP module."""

geo_deep_learning/models/segmentation/segformer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22

33
import torch
44
import torch.nn.functional as fn
5-
from models.decoders.segformer_mlp import Decoder
6-
from models.encoders.mix_transformer import DynamicMixTransformer, get_encoder
5+
6+
from geo_deep_learning.models.decoders.segformer_mlp import Decoder
7+
from geo_deep_learning.models.encoders.mix_transformer import (
8+
DynamicMixTransformer,
9+
get_encoder,
10+
)
711

812
from .base import BaseSegmentationModel
913

geo_deep_learning/tasks_with_models/segmentation_dofa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
from torchmetrics.segmentation import MeanIoU
1717
from torchmetrics.wrappers import ClasswiseWrapper
1818

19+
from geo_deep_learning.models.segmentation.dofa import DOFASegmentationModel
20+
from geo_deep_learning.tools.visualization import visualize_prediction
1921
from geo_deep_learning.utils.models import load_weights_from_checkpoint
2022
from geo_deep_learning.utils.tensors import denormalization
21-
from models.segmentation.dofa import DOFASegmentationModel
22-
from tools.visualization import visualize_prediction
2323

2424
# Ignore warning about default grid_sample and affine_grid behavior triggered by kornia
2525
warnings.filterwarnings(

0 commit comments

Comments
 (0)