diff --git a/terratorch/datamodules/fire_scars.py b/terratorch/datamodules/fire_scars.py
index 39038cae..9419b9cf 100644
--- a/terratorch/datamodules/fire_scars.py
+++ b/terratorch/datamodules/fire_scars.py
@@ -15,22 +15,42 @@ from terratorch.datamodules.utils import wrap_in_compose_is_list
 from terratorch.datasets import FireScarsHLS, FireScarsNonGeo, FireScarsSegmentationMask
 
-MEANS = {
-    "BLUE": 0.033349706741586264,
-    "GREEN": 0.05701185520536176,
-    "RED": 0.05889748132001316,
-    "NIR_NARROW": 0.2323245113436119,
-    "SWIR_1": 0.1972854853760658,
-    "SWIR_2": 0.11944914225186566,
+MEANS_PER_VERSION = {
+    '1': {
+        "BLUE": 0.0535,
+        "GREEN": 0.0788,
+        "RED": 0.0963,
+        "NIR_NARROW": 0.2119,
+        "SWIR_1": 0.2360,
+        "SWIR_2": 0.1731,
+    },
+    '2': {
+        "BLUE": 0.0535,
+        "GREEN": 0.0788,
+        "RED": 0.0963,
+        "NIR_NARROW": 0.2119,
+        "SWIR_1": 0.2360,
+        "SWIR_2": 0.1731,
+    }
 }
 
-STDS = {
-    "BLUE": 0.02269135568823774,
-    "GREEN": 0.026807560223070237,
-    "RED": 0.04004109844362779,
-    "NIR_NARROW": 0.07791732423672691,
-    "SWIR_1": 0.08708738838140137,
-    "SWIR_2": 0.07241979477437814,
+STDS_PER_VERSION = {
+    '1': {
+        "BLUE": 0.0308,
+        "GREEN": 0.0378,
+        "RED": 0.0550,
+        "NIR_NARROW": 0.0707,
+        "SWIR_1": 0.0919,
+        "SWIR_2": 0.0841,
+    },
+    '2': {
+        "BLUE": 0.0308,
+        "GREEN": 0.0378,
+        "RED": 0.0550,
+        "NIR_NARROW": 0.0707,
+        "SWIR_1": 0.0919,
+        "SWIR_2": 0.0841,
+    }
 }
@@ -40,6 +60,7 @@ class FireScarsNonGeoDataModule(NonGeoDataModule):
     def __init__(
         self,
         data_root: str,
+        version: str = '2',
         batch_size: int = 4,
         num_workers: int = 0,
         bands: Sequence[str] = FireScarsNonGeo.all_band_names,
@@ -54,14 +75,16 @@ def __init__(
     ) -> None:
         super().__init__(FireScarsNonGeo, batch_size, num_workers, **kwargs)
         self.data_root = data_root
-
-        means = [MEANS[b] for b in bands]
-        stds = [STDS[b] for b in bands]
+        means = MEANS_PER_VERSION[version]
+        stds = STDS_PER_VERSION[version]
+        self.means = [means[b] for b in bands]
+        self.stds = [stds[b] for b in bands]
+        self.version = version
         self.bands = bands
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
         self.test_transform = wrap_in_compose_is_list(test_transform)
-        self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=["image"])
+        self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"])
         self.drop_last = drop_last
         self.no_data_replace = no_data_replace
         self.no_label_replace = no_label_replace
@@ -71,6 +94,7 @@ def setup(self, stage: str) -> None:
         if stage in ["fit"]:
             self.train_dataset = self.dataset_class(
                 split="train",
+                version=self.version,
                 data_root=self.data_root,
                 transform=self.train_transform,
                 bands=self.bands,
@@ -81,6 +105,7 @@ def setup(self, stage: str) -> None:
         if stage in ["fit", "validate"]:
             self.val_dataset = self.dataset_class(
                 split="val",
+                version=self.version,
                 data_root=self.data_root,
                 transform=self.val_transform,
                 bands=self.bands,
@@ -90,7 +115,8 @@ def setup(self, stage: str) -> None:
         if stage in ["test"]:
             self.test_dataset = self.dataset_class(
-                split="val",
+                split="test",
+                version=self.version,
                 data_root=self.data_root,
                 transform=self.test_transform,
                 bands=self.bands,
diff --git a/terratorch/datamodules/multi_temporal_crop_classification.py b/terratorch/datamodules/multi_temporal_crop_classification.py
index 4957e088..c291a911 100644
--- a/terratorch/datamodules/multi_temporal_crop_classification.py
+++ b/terratorch/datamodules/multi_temporal_crop_classification.py
@@ -10,22 +10,42 @@ from terratorch.datamodules.utils import wrap_in_compose_is_list
 from terratorch.datasets import MultiTemporalCropClassification
 
-MEANS = {
-    "BLUE": 494.905781,
-    "GREEN": 815.239594,
-    "RED": 924.335066,
-    "NIR_NARROW": 2968.881459,
-    "SWIR_1": 2634.621962,
-    "SWIR_2": 1739.579917,
+MEANS_PER_VERSION = {
+    '1': {
+        "BLUE": 830.5397,
+        "GREEN": 2427.1667,
+        "RED": 760.6795,
+        "NIR_NARROW": 2575.2020,
+        "SWIR_1": 649.9128,
+        "SWIR_2": 2344.4357,
+    },
+    '2': {
+        "BLUE": 829.5907,
+        "GREEN": 2437.3473,
+        "RED": 748.6308,
+        "NIR_NARROW": 2568.9369,
+        "SWIR_1": 638.9926,
+        "SWIR_2": 2336.4087,
+    }
 }
 
-STDS = {
-    "BLUE": 284.925432,
-    "GREEN": 357.84876,
-    "RED": 575.566823,
-    "NIR_NARROW": 896.601013,
-    "SWIR_1": 951.900334,
-    "SWIR_2": 921.407808,
+STDS_PER_VERSION = {
+    '1': {
+        "BLUE": 447.9155,
+        "GREEN": 910.8289,
+        "RED": 490.9398,
+        "NIR_NARROW": 1142.5207,
+        "SWIR_1": 430.9440,
+        "SWIR_2": 1094.0881,
+    },
+    '2': {
+        "BLUE": 447.1192,
+        "GREEN": 913.5633,
+        "RED": 480.5570,
+        "NIR_NARROW": 1140.6160,
+        "SWIR_1": 418.6212,
+        "SWIR_2": 1091.6073,
+    }
 }
@@ -35,6 +55,7 @@ class MultiTemporalCropClassificationDataModule(NonGeoDataModule):
     def __init__(
         self,
         data_root: str,
+        version: str = '2',
         batch_size: int = 4,
         num_workers: int = 0,
         bands: Sequence[str] = MultiTemporalCropClassification.all_band_names,
@@ -51,9 +72,11 @@ def __init__(
     ) -> None:
         super().__init__(MultiTemporalCropClassification, batch_size, num_workers, **kwargs)
         self.data_root = data_root
-
-        self.means = [MEANS[b] for b in bands]
-        self.stds = [STDS[b] for b in bands]
+        means = MEANS_PER_VERSION[version]
+        stds = STDS_PER_VERSION[version]
+        self.means = [means[b] for b in bands]
+        self.stds = [stds[b] for b in bands]
+        self.version = version
         self.bands = bands
         self.train_transform = wrap_in_compose_is_list(train_transform)
         self.val_transform = wrap_in_compose_is_list(val_transform)
@@ -70,6 +93,7 @@ def setup(self, stage: str) -> None:
         if stage in ["fit"]:
             self.train_dataset = self.dataset_class(
                 split="train",
+                version=self.version,
                 data_root=self.data_root,
                 transform=self.train_transform,
                 bands=self.bands,
@@ -82,6 +106,7 @@ def setup(self, stage: str) -> None:
         if stage in ["fit", "validate"]:
             self.val_dataset = self.dataset_class(
                 split="val",
+                version=self.version,
                 data_root=self.data_root,
                 transform=self.val_transform,
                 bands=self.bands,
@@ -93,7 +118,8 @@ def setup(self, stage: str) -> None:
         if stage in ["test"]:
             self.test_dataset = self.dataset_class(
-                split="val",
+                split="test",
+                version=self.version,
                 data_root=self.data_root,
                 transform=self.test_transform,
                 bands=self.bands,
diff --git a/terratorch/datasets/fire_scars.py b/terratorch/datasets/fire_scars.py
index f5b65516..1a987763 100644
--- a/terratorch/datasets/fire_scars.py
+++ b/terratorch/datasets/fire_scars.py
@@ -4,9 +4,8 @@
 import glob
 import os
 import re
-from collections.abc import Sequence
 from pathlib import Path
-from typing import Any
+from typing import Any, Sequence
 
 import albumentations as A
 import matplotlib as mpl
@@ -20,11 +19,22 @@
 from torchgeo.datasets import NonGeoDataset, RasterDataset
 from xarray import DataArray
 
-from terratorch.datasets.utils import clip_image_percentile, default_transform, validate_bands
+from terratorch.datasets.utils import clip_image_percentile, default_transform, filter_valid_files, validate_bands
 
 
 class FireScarsNonGeo(NonGeoDataset):
-    """NonGeo dataset implementation for fire scars."""
+    """NonGeo dataset implementation for fire scars.
+
+    If using the version 2 dataset, we use the version 2 train/val/test splits from the dataset.
+    If using the version 1 dataset, we use the version 1 train/val splits from the dataset.
+    """
+    versions = ('1', '2')
+
+    splits_per_version = {
+        '1': {"train": "train", "val": "val", "test": "val"},
+        '2': {"train": "train", "val": "val", "test": "test"},
+    }
+
     all_band_names = (
         "BLUE",
         "GREEN",
@@ -39,11 +49,11 @@ class FireScarsNonGeo(NonGeoDataset):
     BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}
 
     num_classes = 2
-    splits = {"train": "training", "val": "validation"}  # Only train and val splits available
 
     def __init__(
         self,
         data_root: str,
+        version: str = '2',
         split: str = "train",
         bands: Sequence[str] = BAND_SETS["all"],
         transform: A.Compose | None = None,
@@ -66,20 +76,39 @@ def __init__(
             use_metadata (bool): whether to return metadata info (time and location).
         """
         super().__init__()
-        if split not in self.splits:
-            msg = f"Incorrect split '{split}', please choose one of {self.splits}."
+        if version not in self.versions:
+            msg = f"Incorrect version '{version}', please choose one of {self.versions}."
+            raise ValueError(msg)
+        splits = self.splits_per_version[version]
+        if split not in splits:
+            msg = f"Incorrect split '{split}', please choose one of {list(splits.keys())}."
             raise ValueError(msg)
-        split_name = self.splits[split]
-        self.split = split
+        self.split = splits[split]
 
         validate_bands(bands, self.all_band_names)
         self.bands = bands
         self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
         self.data_root = Path(data_root)
 
-        input_dir = self.data_root / split_name
-        self.image_files = sorted(glob.glob(os.path.join(input_dir, "*_merged.tif")))
-        self.segmentation_mask_files = sorted(glob.glob(os.path.join(input_dir, "*.mask.tif")))
+        self.image_files = sorted(glob.glob(os.path.join(self.data_root, "*_merged.tif")))
+        self.segmentation_mask_files = sorted(glob.glob(os.path.join(self.data_root, "*.mask.tif")))
+
+        split_file = self.data_root / f"{self.split}_v{version}_data.txt"
+        with open(split_file) as f:
+            split = f.readlines()
+        valid_files = {rf"{substring.strip()}" for substring in split}
+        self.image_files = filter_valid_files(
+            self.image_files,
+            valid_files=valid_files,
+            ignore_extensions=True,
+            allow_substring=True,
+        )
+        self.segmentation_mask_files = filter_valid_files(
+            self.segmentation_mask_files,
+            valid_files=valid_files,
+            ignore_extensions=True,
+            allow_substring=True,
+        )
 
         self.use_metadata = use_metadata
         self.no_data_replace = no_data_replace
diff --git a/terratorch/datasets/multi_temporal_crop_classification.py b/terratorch/datasets/multi_temporal_crop_classification.py
index 709800d4..262edced 100644
--- a/terratorch/datasets/multi_temporal_crop_classification.py
+++ b/terratorch/datasets/multi_temporal_crop_classification.py
@@ -22,8 +22,17 @@
 class MultiTemporalCropClassification(NonGeoDataset):
-    """NonGeo dataset implementation for multi-temporal crop classification."""
+    """NonGeo dataset implementation for multi-temporal crop classification.
+    If using the version 2 dataset, we use the version 2 train/val/test splits from the dataset.
+    If using the version 1 dataset, we use the version 1 train/val splits from the dataset.
+    """
+    versions = ('1', '2')
+
+    splits_per_version = {
+        '1': {"train": "train", "val": "val", "test": "val"},
+        '2': {"train": "train", "val": "val", "test": "test"},
+    }
 
     all_band_names = (
         "BLUE",
         "GREEN",
@@ -55,7 +64,6 @@ class MultiTemporalCropClassification(NonGeoDataset):
     num_classes = 13
     time_steps = 3
-    splits = {"train": "training", "val": "validation"}  # Only train and val splits available
     metadata_file_name = "chip_df_final.csv"
     col_name = "chip_id"
     date_columns = ["first_img_date", "middle_img_date", "last_img_date"]
 
@@ -63,6 +71,7 @@
     def __init__(
         self,
         data_root: str,
+        version: str = '2',
         split: str = "train",
         bands: Sequence[str] = BAND_SETS["all"],
         transform: A.Compose | None = None,
@@ -92,21 +101,23 @@ def __init__(
             use_metadata (bool): whether to return metadata info (time and location).
         """
         super().__init__()
-        if split not in self.splits:
-            msg = f"Incorrect split '{split}', please choose one of {self.splits}."
+        if version not in self.versions:
+            msg = f"Incorrect version '{version}', please choose one of {self.versions}."
+            raise ValueError(msg)
+        splits = self.splits_per_version[version]
+        if split not in splits:
+            msg = f"Incorrect split '{split}', please choose one of {list(splits.keys())}."
             raise ValueError(msg)
-        split_name = self.splits[split]
-        self.split = split
+        self.split = splits[split]
 
         validate_bands(bands, self.all_band_names)
         self.bands = bands
         self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
         self.data_root = Path(data_root)
 
-        data_dir = self.data_root / f"{split_name}_chips"
-        self.image_files = sorted(glob.glob(os.path.join(data_dir, "*_merged.tif")))
-        self.segmentation_mask_files = sorted(glob.glob(os.path.join(data_dir, "*.mask.tif")))
-        split_file = data_dir / f"{split_name}_data.txt"
+        self.image_files = sorted(glob.glob(os.path.join(self.data_root, "*_merged.tif")))
+        self.segmentation_mask_files = sorted(glob.glob(os.path.join(self.data_root, "*.mask.tif")))
+        split_file = self.data_root / f"{self.split}_v{version}_data.txt"
 
         with open(split_file) as f:
             split = f.readlines()
@@ -235,7 +246,9 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure
             raise ValueError(msg)
 
         images = sample["image"]
-        images = images[rgb_indices, ...]  # Shape: (T, 3, H, W)
+        if not self.expand_temporal_dimension:
+            images = rearrange(images, "(channels time) h w -> channels time h w", channels=len(self.bands))
+        images = images[rgb_indices, ...]  # Shape: (3, T, H, W)
 
         processed_images = []
         for t in range(self.time_steps):
@@ -247,7 +260,10 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure
         mask = sample["mask"].numpy()
 
         if "prediction" in sample:
+            prediction = sample["prediction"]
             num_images += 1
+        else:
+            prediction = None
 
         fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")
         ax[0].axis("off")
@@ -261,11 +277,9 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure
         ax[self.time_steps + 1].title.set_text("Ground Truth Mask")
         ax[self.time_steps + 1].imshow(mask, cmap="jet", norm=norm)
 
-        if "prediction" in sample:
-            prediction = sample["prediction"]
-            ax[self.time_steps + 1].axis("off")
-            ax[self.time_steps+2].title.set_text("Predicted Mask")
-            ax[self.time_steps+2].imshow(prediction, cmap="jet", norm=norm)
+        if prediction is not None:
+            ax[self.time_steps + 2].title.set_text("Predicted Mask")
+            ax[self.time_steps + 2].imshow(prediction, cmap="jet", norm=norm)
 
         cmap = plt.get_cmap("jet")
         legend_data = [[i, cmap(norm(i)), self.class_names[i]] for i in range(self.num_classes)]
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 229235f5..f059c17e 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -84,29 +84,53 @@ def mpv4ger_data_root(tmp_path):
 @pytest.fixture(scope="function")
 def fire_scars_data_root(tmp_path):
     data_root = tmp_path / "fire_scars"
-    split = "train"
-    split_dir = data_root / FireScarsNonGeo.splits[split]
-    split_dir.mkdir(parents=True, exist_ok=True)
+    data_root.mkdir(parents=True, exist_ok=True)
+
+    chip_ids = []
 
     for i in range(5):
        random_seq = "".join(np.random.choice(list("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"), 5))
        year = 2021
        julian_day = i + 1
        date = f"{year}{julian_day:03d}"
-        image_filename = f"subsetted_512x512_HLS.S30.T{random_seq}.{date}.v1.4_merged.tif"
-        mask_filename = f"subsetted_512x512_HLS.S30.T{random_seq}.{date}.v1.4.mask.tif"
-        image_path = split_dir / image_filename
-        mask_path = split_dir / mask_filename
+        base_name = f"subsetted_512x512_HLS.S30.T{random_seq}.{date}.v1.4"
+        image_filename = f"{base_name}_merged.tif"
+        mask_filename = f"{base_name}.mask.tif"
+        image_path = data_root / image_filename
+        mask_path = data_root / mask_filename
 
         create_dummy_tiff(image_path)
         create_dummy_tiff(mask_path, count=1, dtype="uint8")
 
-    image_files = list(split_dir.glob("*_merged.tif"))
-    mask_files = list(split_dir.glob("*.mask.tif"))
+        chip_ids.append(base_name)
+
+    train_ids = chip_ids[:3]
+    val_ids = chip_ids[3:4]
+    test_ids = chip_ids[4:]
+
+    with open(data_root / "train_v2_data.txt", "w") as f:
+        f.write("\n".join(train_ids))
+
+    with open(data_root / "val_v2_data.txt", "w") as f:
+        f.write("\n".join(val_ids))
+
+    with open(data_root / "test_v2_data.txt", "w") as f:
+        f.write("\n".join(test_ids))
+
+    image_files = list(data_root.glob("*_merged.tif"))
+    mask_files = list(data_root.glob("*.mask.tif"))
 
     assert len(image_files) == 5, f"Expected 5 image files, but found {len(image_files)}"
     assert len(mask_files) == 5, f"Expected 5 mask files, but found {len(mask_files)}"
 
+    split_files = ["train_v2_data.txt", "val_v2_data.txt", "test_v2_data.txt"]
+    for split_file in split_files:
+        file_path = data_root / split_file
+        assert file_path.exists(), f"Expected split file {split_file} to exist."
+        with open(file_path, "r") as f:
+            lines = f.read().splitlines()
+        assert len(lines) > 0, f"Split file {split_file} should not be empty."
+
     return str(data_root)
 
 
 @pytest.fixture(scope="function")
@@ -507,32 +531,30 @@ def chesapeake_data_root(tmp_path):
 @pytest.fixture(scope="function")
 def crop_classification_data_root(tmp_path):
     data_root = tmp_path / "crop_classification"
-    training_dir = data_root / "training_chips"
-    validation_dir = data_root / "validation_chips"
-
-    training_dir.mkdir(parents=True, exist_ok=True)
-    validation_dir.mkdir(parents=True, exist_ok=True)
-
-    for directory in [training_dir, validation_dir]:
-        for i in range(2):
-            filename = f"chip_{i}_merged.tif"
-            label_filename = f"chip_{i}.mask.tif"
-            img_data = DataArray(np.random.rand(18, 64, 64).astype(np.float32), dims=["band", "y", "x"])
-            mask_data = DataArray(np.random.randint(0, 13, size=(1, 64, 64), dtype=np.uint8), dims=["band", "y", "x"])
-
-            img_data.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=True)
-            mask_data.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=True)
-            img_data.rio.write_crs("EPSG:4326", inplace=True)
-            mask_data.rio.write_crs("EPSG:4326", inplace=True)
-            image_path = directory / filename
-            mask_path = directory / label_filename
-            img_data.rio.to_raster(str(image_path))
-            mask_data.rio.to_raster(str(mask_path))
-
-    with open(training_dir / "training_data.txt", "w") as f:
+
+    data_root.mkdir(parents=True, exist_ok=True)
+
+    for i in range(2):
+        filename = f"chip_{i}_merged.tif"
+        label_filename = f"chip_{i}.mask.tif"
+        img_data = DataArray(np.random.rand(18, 64, 64).astype(np.float32), dims=["band", "y", "x"])
+        mask_data = DataArray(np.random.randint(0, 13, size=(1, 64, 64), dtype=np.uint8), dims=["band", "y", "x"])
+        img_data = img_data.rio.set_spatial_dims(x_dim="x", y_dim="y")
+        mask_data = mask_data.rio.set_spatial_dims(x_dim="x", y_dim="y")
+        img_data = img_data.rio.write_crs("EPSG:4326")
+        mask_data = mask_data.rio.write_crs("EPSG:4326")
+        image_path = data_root / filename
+        mask_path = data_root / label_filename
+        img_data.rio.to_raster(str(image_path))
+        mask_data.rio.to_raster(str(mask_path))
+
+    with open(data_root / "train_v2_data.txt", "w") as f:
+        f.write("\n".join([f"chip_{i}" for i in range(2)]))
+
+    with open(data_root / "val_v2_data.txt", "w") as f:
         f.write("\n".join([f"chip_{i}" for i in range(2)]))
 
-    with open(validation_dir / "validation_data.txt", "w") as f:
+    with open(data_root / "test_v2_data.txt", "w") as f:
         f.write("\n".join([f"chip_{i}" for i in range(2)]))
 
     metadata = pd.DataFrame({
@@ -632,7 +654,7 @@ def test_plot(self, eurosat_data_root):
 class TestFireScarsNonGeo:
     def test_dataset_length(self, fire_scars_data_root):
         dataset = FireScarsNonGeo(data_root=fire_scars_data_root, split="train")
-        expected_length = 5
+        expected_length = 3
         actual_length = len(dataset)
         assert actual_length == expected_length, f"Expected {expected_length}, but got {actual_length}"
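Usage sketch: the snippet below illustrates how the new version argument introduced by this patch is exercised. Class, argument, and split-file names are taken from the diff itself; the data_root path, band subset, and chip names are hypothetical placeholders, so treat this as an illustration under those assumptions rather than shipped documentation.

# Illustrative use of the versioned splits (assumes terratorch with this patch
# applied; "/data/fire_scars" is a placeholder path, not a real location).
from terratorch.datamodules.fire_scars import FireScarsNonGeoDataModule
from terratorch.datasets import FireScarsNonGeo

datamodule = FireScarsNonGeoDataModule(
    data_root="/data/fire_scars",    # chips and split files live side by side here
    version="2",                     # picks MEANS_PER_VERSION['2'] / STDS_PER_VERSION['2']
    bands=["RED", "GREEN", "BLUE"],  # any subset of FireScarsNonGeo.all_band_names
)
datamodule.setup("fit")   # builds train/val datasets, each filtered by its split file
datamodule.setup("test")  # "test" now resolves to the real test split for version "2"

# Version "1" ships no test split: splits_per_version maps "test" to "val",
# so this constructor reads val_v1_data.txt instead of raising.
dataset = FireScarsNonGeo(data_root="/data/fire_scars", version="1", split="test")

Both dataset classes now glob *_merged.tif and *.mask.tif directly under data_root and keep only files whose names contain a chip id listed in the matching {split}_v{version}_data.txt file (substring matching via filter_valid_files), so the expected on-disk layout is flat, e.g.:

data_root/
    subsetted_512x512_HLS.S30.T<TILE>.<DATE>.v1.4_merged.tif    image chip
    subsetted_512x512_HLS.S30.T<TILE>.<DATE>.v1.4.mask.tif      segmentation mask
    train_v2_data.txt    one chip id per line
    val_v2_data.txt
    test_v2_data.txt

MultiTemporalCropClassification and its datamodule follow the same pattern with the same split-file naming.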