diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 821424a..72c99b7 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -21,7 +21,7 @@ jobs: lfs: false - name: Set up Python 3.11 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.11.13 pip-version: 24 diff --git a/olmoearth_projects/olmoearth_run/olmoearth_run.py b/olmoearth_projects/olmoearth_run/olmoearth_run.py index 6807b6a..ad9b337 100644 --- a/olmoearth_projects/olmoearth_run/olmoearth_run.py +++ b/olmoearth_projects/olmoearth_run/olmoearth_run.py @@ -173,6 +173,7 @@ def one_stage( fn = runner.run_inference elif stage == OlmoEarthRunStage.POSTPROCESS: fn = runner.postprocess + runner.inference_results_data_type = "RASTER" # Needs to be set before postprocessing else: assert False @@ -185,4 +186,5 @@ def one_stage( fn(partition_id) elif stage == OlmoEarthRunStage.COMBINE: + runner.inference_results_data_type = "RASTER" # Needs to be set before combining runner.combine(partitions) diff --git a/olmoearth_projects/projects/yemen_crop/dense_polygon_to_raster.py b/olmoearth_projects/projects/yemen_crop/dense_polygon_to_raster.py new file mode 100644 index 0000000..aaa28f4 --- /dev/null +++ b/olmoearth_projects/projects/yemen_crop/dense_polygon_to_raster.py @@ -0,0 +1,341 @@ +""" +Polygon to Raster Window Preparer. + +This module provides a window preparer that converts polygon/multipolygon annotations +to raster labels. It creates a grid around the entire area and then splits the grid into windows. +ALL polygon labels intersected by the window are assigned to that window. 
+""" + +from shapely.geometry import Polygon +from shapely.geometry.base import BaseGeometry +from typing import cast +import math +import numpy as np +from shapely.geometry import box +from olmoearth_run.runner.models.training.labeled_data import ( + AnnotationTask, + LabeledWindow, + RasterLabel, +) +from olmoearth_run.runner.tools.labeled_window_preparers.labeled_window_preparer import ( + RasterLabelsWindowPreparer, +) +from olmoearth_run.runner.tools.labeled_window_preparers.geometry_utils import ( + # compute_window_bounds, + create_raster_labels_from_annotations, + project_geometry_to_crs, +) +from olmoearth_run.runner.tools.labeled_window_preparers.rasterization_utils import ( + DEFAULT_NODATA_VALUE, +) +from rslearn.utils import STGeometry, get_utm_ups_crs + + +def grid_split_raster_labels( + full_window_bounds: tuple[int, int, int, int], # These are PIXEL coordinates + full_raster_labels: list[RasterLabel], + input_size: int, + nodata_value: int, + base_st_geometry: STGeometry, + task_id: str, +) -> list[LabeledWindow]: + # Extract pixel bounds + minx_px, miny_px, maxx_px, maxy_px = full_window_bounds + + # Dimensions in pixels + full_width_px = maxx_px - minx_px + full_height_px = maxy_px - miny_px + + assert full_width_px % input_size == 0, f"Full width must be divisible by input size got {full_width_px} and {input_size}" + assert full_height_px % input_size == 0, f"Full height must be divisible by input size got {full_height_px} and {input_size}" + assert len(full_raster_labels) == 1, "Only a single raster of labels is currently supported" + # Number of tiles (working in pixel space) + num_tiles_x = math.ceil(full_width_px / input_size) + num_tiles_y = math.ceil(full_height_px / input_size) + + labeled_windows = [] + + for tile_y in range(num_tiles_y): + for tile_x in range(num_tiles_x): + # Tile bounds in PIXEL coordinates + tile_minx_px = minx_px + (tile_x * input_size) + tile_miny_px = miny_px + (tile_y * input_size) + tile_maxx_px = tile_minx_px + 
input_size + tile_maxy_px = tile_miny_px + input_size + + # Array slicing (also in pixels, but relative to array origin) + # Array indices start at 0, so offset by the full bounds + array_x_start = tile_x * input_size + array_y_start = tile_y * input_size + array_x_end = array_x_start + input_size + array_y_end = array_y_start + input_size + # Slice raster labels + full_label = full_raster_labels[0] + tile_array = full_label.value[array_y_start:array_y_end, array_x_start:array_x_end] + if np.all(tile_array == nodata_value): + continue + tile_raster_labels = [RasterLabel(key=full_label.key, value=tile_array)] + + + # Create tile geometry in PIXEL space with same projection + tile_polygon = box(tile_minx_px, tile_miny_px, tile_maxx_px, tile_maxy_px) + tile_st_geometry = STGeometry( + base_st_geometry.projection, # Same projection (includes resolution) + tile_polygon, # Polygon in pixel coordinates + base_st_geometry.time_range, + ) + + tile_name = f"task_{task_id}_tile_{tile_y}_{tile_x}" + labeled_windows.append( + LabeledWindow(name=tile_name, st_geometry=tile_st_geometry, labels=tile_raster_labels) + ) + + return labeled_windows + + +def pad_raster_labels_to_input_size( + full_raster_labels: list[RasterLabel], + input_size: int, + nodata_value: int = DEFAULT_NODATA_VALUE, +) -> list[RasterLabel]: + """Pad raster labels to the next multiple of input_size.""" + padded_labels = [] + + for label in full_raster_labels: + height, width = label.value.shape + + # Calculate padded dimensions + padded_height = math.ceil(height / input_size) * input_size + padded_width = math.ceil(width / input_size) * input_size + + # Calculate padding amounts (pad on right and bottom) + pad_height = padded_height - height + pad_width = padded_width - width + + # Pad the array with nodata_value + padded_array = np.pad( + label.value, + ((0, pad_height), (0, pad_width)), + mode='constant', + constant_values=nodata_value + ) + + padded_labels.append(RasterLabel(key=label.key, 
value=padded_array)) + + return padded_labels + + +def pad_window_bounds_to_input_size( + full_window_bounds: tuple[int, int, int, int], + input_size: int, +) -> tuple[int, int, int, int]: + """Pad window bounds to the next multiple of input_size.""" + minx_px, miny_px, maxx_px, maxy_px = full_window_bounds + + # Calculate current dimensions + width = maxx_px - minx_px + height = maxy_px - miny_px + + # Calculate padded dimensions + padded_width = math.ceil(width / input_size) * input_size + padded_height = math.ceil(height / input_size) * input_size + + # Extend max bounds (keep min bounds fixed) + padded_maxx_px = minx_px + padded_width + padded_maxy_px = miny_px + padded_height + + return (minx_px, miny_px, padded_maxx_px, padded_maxy_px) + + +def pad_st_geometry_to_input_size( + base_st_geometry: STGeometry, + full_window_bounds: tuple[int, int, int, int], +) -> STGeometry: + """Update STGeometry to match padded window bounds.""" + minx, miny, maxx, maxy = full_window_bounds + + # Create new bounding box with padded dimensions + padded_polygon = box(minx, miny, maxx, maxy) + + return STGeometry( + base_st_geometry.projection, + padded_polygon, + base_st_geometry.time_range, + ) + + +def pad_raster_bounds_and_geometry( + full_raster_labels: list[RasterLabel], + base_st_geometry: STGeometry, + full_window_bounds: tuple[int, int, int, int], + input_size: int, + nodata_value: int = DEFAULT_NODATA_VALUE, +) -> tuple[list[RasterLabel], STGeometry, tuple[int, int, int, int]]: + """Pad raster labels, geometry, and bounds to the next multiple of input_size.""" + # Pad the full window bounds to the next multiple of the input size + padded_window_bounds = pad_window_bounds_to_input_size(full_window_bounds, input_size) + + # Pad the full raster labels to the next multiple of the input size + padded_raster_labels = pad_raster_labels_to_input_size(full_raster_labels, input_size, nodata_value) + + # Pad the base st geometry to the next multiple of the input size + 
padded_st_geometry = pad_st_geometry_to_input_size(base_st_geometry, padded_window_bounds) + + return padded_raster_labels, padded_st_geometry, padded_window_bounds + +# THIS Function doesn't assume the all positive nature of utm +def compute_window_bounds(window_geometry: STGeometry) -> tuple[int, int, int, int]: + """Compute integer bounds for the window from the window geometry. + + Args: + window_geometry: The window geometry + + Returns: + Tuple of (minx, miny, maxx, maxy) in world coordinates + """ + bounds = cast(BaseGeometry, window_geometry.shp).bounds + + + minx = math.floor((bounds[0])) + miny = math.floor((bounds[1])) + maxx = math.ceil((bounds[2])) + maxy = math.ceil((bounds[3])) + + return (minx, miny, maxx, maxy) + + +class DensePolygonToRasterWindowPreparer(RasterLabelsWindowPreparer): + """ + Window preparer that converts dense polygon/multipolygon annotations to raster labels. + + This preparer creates gridded windows around the entire annotation task area, + then splits the rasterized labels into smaller tiles. + + Key characteristics: + - Multiple windows per task (tiled grid) + - Each tile has aligned raster labels + - Labels are uint8 raster arrays + - Uses UTM projection for consistent resolution + """ + + def __init__( + self, + window_resolution: float = 10.0, + input_size: int = 16, + nodata_value: int = DEFAULT_NODATA_VALUE + ): + """ + Initialize the DensePolygonToRasterWindowPreparer. + + Args: + window_resolution: Resolution in meters per pixel (default: 10.0) + input_size: Size of each tile in pixels (default: 512x512) + nodata_value: Value to use for nodata pixels + """ + self.window_resolution = window_resolution + self.input_size = input_size + self.nodata_value = nodata_value + + + def prepare_labeled_windows( + self, annotation_task: AnnotationTask + ) -> list[LabeledWindow[list[RasterLabel]]]: + """ + Prepare labeled windows from polygon annotation tasks. 
+ + This method creates one window per annotation task, using the task geometry + as the window boundary. It rasterizes all polygon annotations within the task + into a single uint8 raster label. + + Args: + annotation_task: Single AnnotationTask object containing task context and annotations + + Returns: + List containing one LabeledWindow object with raster labels, or empty list if no annotations + """ + if not annotation_task.annotations: + return [] + + # Calculate CRS based on task centroid + # First check what projection the task geometry is in + print(f"Task geometry projection: {annotation_task.task_st_geometry.projection}") + print(f"Task geometry projection CRS: {annotation_task.task_st_geometry.projection.crs}") + + task_centroid = cast( + BaseGeometry, annotation_task.task_st_geometry.shp + ).centroid + task_bounds = cast( + BaseGeometry, annotation_task.task_st_geometry.shp + ).bounds + print(f"Task bounds before projection: {task_bounds}") + + corner_coords = [ + (task_bounds[0], task_bounds[1]), # low left + (task_bounds[2], task_bounds[3]), # high right + (task_bounds[0], task_bounds[3]), # low right + (task_bounds[2], task_bounds[1]), # high left + ] + print(f"Corner coords: {corner_coords}") + corner_utm_crs = [get_utm_ups_crs(coord[0], coord[1]) for coord in corner_coords] + print(f"Corner utm crs: {corner_utm_crs}") + assert all(utm_crs == corner_utm_crs[0] for utm_crs in corner_utm_crs), "Corner utm crs are not all the same" + utm_crs = corner_utm_crs[0] + utm_crs = get_utm_ups_crs(task_centroid.x, task_centroid.y) + # Extract the task geometry + task_geom = annotation_task.task_st_geometry.shp + if not isinstance( + task_geom, (Polygon, BaseGeometry) + ) or task_geom.geom_type not in ["Polygon", "MultiPolygon"]: + raise ValueError( + f"Expected Polygon or MultiPolygon for task, got {type(task_geom)} with geom_type {getattr(task_geom, 'geom_type', 'unknown')}" + ) + + # Convert to appropriate projection if needed + projected_geometry = 
project_geometry_to_crs( + task_geom, self.window_resolution, utm_crs + ) + + # Create the window geometry + window_st_geometry = STGeometry( + projected_geometry.projection, + projected_geometry.shp, + annotation_task.task_st_geometry.time_range, + ) + + # Create the full raster label by rasterizing all polygon annotations + window_bounds = compute_window_bounds(projected_geometry) + full_raster_labels = create_raster_labels_from_annotations( + annotations=annotation_task.annotations, + window_bounds=window_bounds, + window_resolution=self.window_resolution, + crs=utm_crs, + nodata_value=self.nodata_value, + ) + print(f"Full raster labels shape: {full_raster_labels[0].value.shape}") + print(f"Window bounds before padding: {window_bounds}") + + # Pad to the next multiple of input_size + padded_raster_labels, padded_st_geometry, padded_window_bounds = pad_raster_bounds_and_geometry( + full_raster_labels=full_raster_labels, + base_st_geometry=window_st_geometry, + full_window_bounds=window_bounds, + input_size=self.input_size, + nodata_value=self.nodata_value, + ) + + print(f"Padded raster shape: {padded_raster_labels[0].value.shape}") + print(f"Padded window bounds: {padded_window_bounds}") + + # Split the full window into tiles + labeled_windows = grid_split_raster_labels( + full_window_bounds=padded_window_bounds, + full_raster_labels=padded_raster_labels, + input_size=self.input_size, + nodata_value=self.nodata_value, + base_st_geometry=padded_st_geometry, + task_id=str(annotation_task.task_id), + ) + print(f"Created {len(labeled_windows)} labeled windows") + + + return labeled_windows \ No newline at end of file diff --git a/olmoearth_run_data/mangrove/model.yaml b/olmoearth_run_data/mangrove/model.yaml index 5658767..da713ca 100644 --- a/olmoearth_run_data/mangrove/model.yaml +++ b/olmoearth_run_data/mangrove/model.yaml @@ -9,7 +9,8 @@ model: encoder: - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth init_args: - model_id: OLMOEARTH_V1_BASE + 
model_path: ${EXTRA_FILES_PATH} + embedding_size: 768 patch_size: 2 decoders: mangrove_classification: diff --git a/olmoearth_run_data/yemen_crop_mapping/class_map.json b/olmoearth_run_data/yemen_crop_mapping/class_map.json new file mode 100644 index 0000000..256c4b8 --- /dev/null +++ b/olmoearth_run_data/yemen_crop_mapping/class_map.json @@ -0,0 +1,11 @@ +{ + "orchards": 0, + "coffee": 1, + "inactive_cropland": 2, + "cereals": 3, + "not_cropland": 4, + "greenhouse": 5, + "fodder": 6, + "mixed_other": 7, + "qat": 8 +} \ No newline at end of file diff --git a/olmoearth_run_data/yemen_crop_mapping/dataset.json b/olmoearth_run_data/yemen_crop_mapping/dataset.json new file mode 100644 index 0000000..0fbcdce --- /dev/null +++ b/olmoearth_run_data/yemen_crop_mapping/dataset.json @@ -0,0 +1,63 @@ +{ + "layers": { + "label": { + "band_sets": [ + { + "bands": [ + "crop_land" + ], + "dtype": "int32" + } + ], + "type": "raster" + }, + "output": { + "band_sets": [ + { + "bands": [ + "output" + ], + "dtype": "float32" + } + ], + "type": "raster" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B11", + "B12" + ], + "dtype": "uint16" + } + ], + "data_source": { + "class_path": "rslearn.data_sources.planetary_computer.Sentinel2", + "ingest": false, + "init_args": { + "cache_dir": "cache/planetary_computer", + "harmonize": true, + "sort_by": "eo:cloud_cover" + }, + "query_config": { + "max_matches": 12, + "min_matches": 12, + "period_duration": "30d", + "space_mode": "PER_PERIOD_MOSAIC" + } + }, + "type": "raster" + } + } + } diff --git a/olmoearth_run_data/yemen_crop_mapping/model.yaml b/olmoearth_run_data/yemen_crop_mapping/model.yaml new file mode 100644 index 0000000..e4bbdd2 --- /dev/null +++ b/olmoearth_run_data/yemen_crop_mapping/model.yaml @@ -0,0 +1,433 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: 
rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + model_id: OLMOEARTH_V1_BASE + patch_size: 1 + decoders: + segment: + - class_path: rslearn.models.conv.Conv + init_args: + in_channels: 768 + out_channels: 9 + kernel_size: 1 + activation: + class_path: torch.nn.Identity + - class_path: rslearn.models.pick_features.PickFeatures + init_args: + indexes: [0] + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lr: 0.001 + plateau: true + plateau_factor: 0.5 + plateau_patience: 2 + plateau_min_lr: 0.00005 + plateau_cooldown: 20 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: ${DATASET_PATH} + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_item_groups: true + load_all_layers: true + targets: + data_type: "raster" + layers: ["label"] + bands: ["crop_land"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 9 + nodata_value: 255 + metric_kwargs: + average: "micro" + other_metrics: + # Macro and weighted averages + macro_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 9 + average: "macro" + macro_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 9 + average: "macro" + macro_f1: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassF1Score + init_args: + 
num_classes: 9 + average: "macro" + weighted_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 9 + average: "weighted" + weighted_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 9 + average: "weighted" + weighted_f1: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassF1Score + init_args: + num_classes: 9 + average: "weighted" + # Per-class metrics: orchards (class 0) + orchards_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 9 + average: null + class_idx: 0 + orchards_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 9 + average: null + class_idx: 0 + orchards_f1: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassF1Score + init_args: + num_classes: 9 + average: null + class_idx: 0 + # Per-class metrics: coffee (class 1) + coffee_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 9 + average: null + class_idx: 1 + coffee_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 9 + average: null + class_idx: 1 + coffee_f1: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric 
+ init_args: + metric: + class_path: torchmetrics.classification.MulticlassF1Score + init_args: + num_classes: 9 + average: null + class_idx: 1 + # Per-class metrics: inactive_cropland (class 2) + inactive_cropland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 9 + average: null + class_idx: 2 + inactive_cropland_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 9 + average: null + class_idx: 2 + inactive_cropland_f1: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassF1Score + init_args: + num_classes: 9 + average: null + class_idx: 2 + # Per-class metrics: cereals (class 3) + cereals_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 9 + average: null + class_idx: 3 + cereals_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 9 + average: null + class_idx: 3 + cereals_f1: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassF1Score + init_args: + num_classes: 9 + average: null + class_idx: 3 + # Per-class metrics: not_cropland (class 4) + not_cropland_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 9 + average: null + class_idx: 4 + not_cropland_recall: + class_path: 
rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 9 + average: null + class_idx: 4 + not_cropland_f1: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassF1Score + init_args: + num_classes: 9 + average: null + class_idx: 4 + # Per-class metrics: greenhouse (class 5) + greenhouse_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 9 + average: null + class_idx: 5 + greenhouse_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 9 + average: null + class_idx: 5 + greenhouse_f1: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassF1Score + init_args: + num_classes: 9 + average: null + class_idx: 5 + # Per-class metrics: fodder (class 6) + fodder_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 9 + average: null + class_idx: 6 + fodder_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 9 + average: null + class_idx: 6 + fodder_f1: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassF1Score + init_args: + num_classes: 9 + average: null + class_idx: 6 + # Per-class metrics: mixed_other (class 7) + mixed_other_precision: + class_path: 
rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 9 + average: null + class_idx: 7 + mixed_other_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 9 + average: null + class_idx: 7 + mixed_other_f1: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassF1Score + init_args: + num_classes: 9 + average: null + class_idx: 7 + # Per-class metrics: qat (class 8) + qat_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 9 + average: null + class_idx: 8 + qat_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 9 + average: null + class_idx: 8 + qat_f1: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassF1Score + init_args: + num_classes: 9 + average: null + class_idx: 8 + input_mapping: + segment: + targets: targets + batch_size: 16 # could I use a bigger batch size? 
+ num_workers: ${NUM_WORKERS} + default_config: + patch_size: 16 # input_size for the model + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + train_config: + transforms: + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["sentinel2_l2a", "target/segment/classes", "target/segment/valid"] + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + groups: ["spatial_split"] + tags: + split: "train" + val_config: + load_all_patches: true + groups: ["spatial_split"] + tags: + split: "val" + test_config: + load_all_patches: true + groups: ["spatial_split"] + tags: + split: "test" + predict_config: + load_all_patches: true + patch_size: 16 + skip_targets: true +trainer: + max_epochs: 800 + strategy: + class_path: lightning.pytorch.strategies.DDPStrategy + init_args: + find_unused_parameters: false + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: ${WANDB_PROJECT} + name: ${WANDB_NAME} + entity: ${WANDB_ENTITY} + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_segment/macro_f1 + mode: max + dirpath: ${TRAINER_DATA_PATH} + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 100 + unfreeze_lr_factor: 5 + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: ${DATASET_PATH} + output_layer: ${PREDICTION_OUTPUT_LAYER} + selector: ["segment"] +# EXTRA_FILES_PATH: ${EXTRA_FILES_PATH} \ No newline at 
end of file diff --git a/olmoearth_run_data/yemen_crop_mapping/olmoearth_run.yaml b/olmoearth_run_data/yemen_crop_mapping/olmoearth_run.yaml new file mode 100644 index 0000000..34f2d86 --- /dev/null +++ b/olmoearth_run_data/yemen_crop_mapping/olmoearth_run.yaml @@ -0,0 +1,86 @@ +window_prep: + sampler: + class_path: olmoearth_run.runner.tools.samplers.noop_sampler.NoopSampler + labeled_window_preparer: + class_path: olmoearth_projects.projects.yemen_crop.dense_polygon_to_raster.DensePolygonToRasterWindowPreparer + init_args: + window_resolution: 10.0 + input_size: 64 + data_splitter: + class_path: olmoearth_run.runner.tools.data_splitters.spatial_data_splitter.SpatialDataSplitter + init_args: + train_prop: 0.8 + val_prop: 0.20 + test_prop: 0.0 + grid_size: 128 # in pixels + label_layer: "label" + group_name: "spatial_split" + split_property: "split" + +partition_strategies: + partition_request_geometry: + class_path: olmoearth_run.runner.tools.partitioners.grid_partitioner.GridPartitioner + init_args: + grid_size: 1.0 + + prepare_window_geometries: + class_path: olmoearth_run.runner.tools.partitioners.grid_partitioner.GridPartitioner + init_args: + grid_size: 1024 + output_projection: + class_path: rslearn.utils.geometry.Projection + init_args: + crs: EPSG:3857 + x_resolution: 10 + y_resolution: -10 + use_utm: true + +postprocessing_strategies: + process_dataset: + class_path: olmoearth_run.runner.tools.postprocessors.combine_geotiff.CombineGeotiff + init_args: + nodata_value: 10 # does this need to align with the other nodata values for training + + process_partition: + class_path: olmoearth_run.runner.tools.postprocessors.combine_geotiff.CombineGeotiff + init_args: + nodata_value: 10 + + process_window: + class_path: olmoearth_run.runner.tools.postprocessors.noop_raster.NoopRaster + +inference_results_config: + data_type: RASTER + classification_fields: + - property_name: segment + band_index: 1 + allowed_values: + - value: 0 + label: orchards + color: [218, 112, 
214] # purple + - value: 1 + label: coffee + color: [165, 42, 42] # brown + - value: 2 + label: inactive_cropland + color: [128, 128, 128] # gray + - value: 3 + label: cereals + color: [255, 215, 0] # gold + - value: 4 + label: not_cropland + color: [70, 130, 180] # steel blue + - value: 5 + label: greenhouse + color: [0, 255, 127] # spring green + - value: 6 + label: fodder + color: [255, 99, 71] # tomato + - value: 7 + label: mixed_other + color: [0, 191, 255] # deep sky blue + - value: 8 + label: qat + color: [189, 183, 107] # dark khaki + # detection_objects: null + # regression_fields: null diff --git a/olmoearth_run_data/yemen_crop_mapping/prediction_request_geometry.geojson b/olmoearth_run_data/yemen_crop_mapping/prediction_request_geometry.geojson new file mode 100644 index 0000000..bdae79c --- /dev/null +++ b/olmoearth_run_data/yemen_crop_mapping/prediction_request_geometry.geojson @@ -0,0 +1,55 @@ +{ + "type": "FeatureCollection", + "name": "aois_for_ai2", + "crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } }, + "features": [ + { + "type": "Feature", + "properties": { + "fid": 2, + "region_name": "sibah", + "expected_crops": "coffee,qat", + "area_sqkm": 21, + "area_sqkm2": 21, + "oe_start_time": "2024-01-01T00:00:00Z", + "oe_end_time": "2024-12-31T00:00:00Z" + }, + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [45.327819031185044, 13.820117498811239], + [45.326799895172805, 13.862021573532653], + [45.368211109923109, 13.862981105872963], + [45.369222903216922, 13.821074019067382], + [45.327819031185044, 13.820117498811239] + ] + ] + } + }, + { + "type": "Feature", + "properties": { + "fid": 4, + "region_name": "sarar", + "expected_crops": "cereals,fodder", + "area_sqkm": 26, + "area_sqkm2": 26, + "oe_start_time": "2024-01-01T00:00:00Z", + "oe_end_time": "2024-12-31T00:00:00Z" + }, + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [45.280264440081169, 13.612629014160101], + [45.279150579341554, 
13.658755912721082], + [45.325378235906378, 13.659821088266645], + [45.326483214506162, 13.613690458477992], + [45.280264440081169, 13.612629014160101] + ] + ] + } + } + ] +} diff --git a/scripts/polygons_to_annotation_features.py b/scripts/polygons_to_annotation_features.py new file mode 100644 index 0000000..c4df4e8 --- /dev/null +++ b/scripts/polygons_to_annotation_features.py @@ -0,0 +1,200 @@ +""" +Convert a GeoJSON of labeled polygons (e.g. Dhamar.geojson) into the two files that +`olmoearth_run` expects for fine-tuning: + +- annotation_task_features.geojson +- annotation_features.geojson + +`olmoearth_run` expects: +- Each task is a GeoJSON Feature with: + - properties.oe_annotations_task_id (UUID) + - properties.oe_start_time / properties.oe_end_time (ISO-8601 datetimes) + - geometry: Polygon/MultiPolygon (task boundary) +- Each annotation is a GeoJSON Feature with: + - properties.oe_annotations_task_id (UUID) (must match some task above) + - properties.oe_labels: dict[str, int|float|None] (we use {"category": class_id}) + - optional oe_start_time / oe_end_time + - geometry: Polygon/MultiPolygon (annotation geometry; we clip to the task boundary) + + +""" + +from __future__ import annotations + +import argparse +import json +import math +import os +import uuid +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any, Iterable + +from shapely.geometry import mapping as shapely_mapping +from shapely.geometry import shape as shapely_shape +from shapely.geometry import box as shapely_box +from shapely.geometry.base import BaseGeometry +from shapely.geometry import Polygon, MultiPolygon +from shapely.ops import unary_union +from olmoearth_run.runner.models.training.annotation_features import AnnotationTaskFeature, AnnotationFeature, AnnotationTaskFeatureProperties, AnnotationFeatureProperties +from pydantic import BaseModel, Field, ConfigDict +from geojson_pydantic.features import FeatureCollection 
from olmoearth_run.shared.models.model_stage_paths import (
    ANNOTATION_FEATURES_FILE_NAME,
    ANNOTATION_TASK_FEATURES_FILE_NAME,
)

# 9-class Yemen crop-type mapping.
# NOTE: These are 0-indexed class IDs (common for segmentation).
YEMEN_CROP_CLASSES: list[str] = [
    "orchards",
    "coffee",
    "inactive_cropland",
    "cereals",
    "not_cropland",
    "greenhouse",
    "fodder",
    "mixed_other",
    "qat",
]
CLASS_TO_ID: dict[str, int] = {name: i for i, name in enumerate(YEMEN_CROP_CLASSES)}

# Property key on input features that holds the class name; also used as the
# key in oe_labels.  NOTE(review): the module docstring describes
# {"category": class_id} — confirm "crop_land" is the key downstream expects.
LABEL_KEY = "crop_land"


class YemenCropLabelsProperties(BaseModel):
    """Properties carried by one labeled polygon in the input GeoJSON."""

    start_time: datetime
    category: int | str


class YemenCropFeature(BaseModel):
    """A single labeled polygon feature from the input GeoJSON."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    type: str = Field(default="Feature")
    geometry: Polygon | MultiPolygon
    properties: YemenCropLabelsProperties


class YemenCropLabels(BaseModel):
    """Top-level FeatureCollection of labeled polygons."""

    type: str = Field(default="FeatureCollection")
    features: list[YemenCropFeature]


# The start and end time are not the period when the label is valid, but more
# of a contextual choice that may need to be configured and injected.


def dense_polygons_to_annotation_features(
    geojson_path: Path,
) -> tuple[AnnotationTaskFeature, list[AnnotationFeature]]:
    """Convert one GeoJSON of dense labeled polygons into olmoearth_run inputs.

    Creates a single task whose boundary is the union of every polygon in the
    file, plus one annotation feature per input polygon. All features share one
    task id and one start/end time range.

    Args:
        geojson_path: Path to a GeoJSON FeatureCollection whose features carry
            ``start_time``/``end_time`` properties and a class name under
            ``LABEL_KEY``.

    Returns:
        A tuple of (task feature, list of annotation features).

    Raises:
        ValueError: If the file contains no features, or features disagree on
            their start/end times.
        KeyError: If a feature's class name is not in ``YEMEN_CROP_CLASSES``.
    """
    with open(geojson_path, "r") as f:
        labels = json.load(f)

    label_features = labels["features"]
    if not label_features:
        raise ValueError(f"No features found in {geojson_path}")

    # All features are grouped into a single task, so they must share one time
    # range. Previously only the first feature was consulted; now we verify
    # the invariant instead of silently assuming it.
    oe_start_time = label_features[0]["properties"]["start_time"]
    oe_end_time = label_features[0]["properties"]["end_time"]
    for feature in label_features[1:]:
        props = feature["properties"]
        if props["start_time"] != oe_start_time or props["end_time"] != oe_end_time:
            raise ValueError(
                f"Inconsistent start/end times in {geojson_path}: all features "
                "must share a single time range"
            )

    # Task boundary = union of every labeled polygon, normalized to a
    # MultiPolygon so the task geometry type is consistent across inputs.
    merged = unary_union([shapely_shape(f["geometry"]) for f in label_features])
    merged_geometry = merged if isinstance(merged, MultiPolygon) else MultiPolygon([merged])

    task_id = uuid.uuid4()
    annotation_task_feature = AnnotationTaskFeature(
        type="Feature",
        geometry=merged_geometry,
        properties=AnnotationTaskFeatureProperties(
            oe_annotations_task_id=task_id,
            oe_start_time=oe_start_time,
            oe_end_time=oe_end_time,
        ),
    )

    annotation_features = []
    for label_feature in label_features:
        # Map the class name to its integer id; unknown names raise KeyError.
        label_idx = CLASS_TO_ID[label_feature["properties"][LABEL_KEY]]
        annotation_features.append(
            AnnotationFeature(
                type="Feature",
                geometry=shapely_shape(label_feature["geometry"]),
                properties=AnnotationFeatureProperties(
                    oe_labels={LABEL_KEY: label_idx},
                    oe_annotations_task_id=task_id,
                    oe_start_time=oe_start_time,
                    oe_end_time=oe_end_time,
                ),
            )
        )
    return annotation_task_feature, annotation_features


def process_input(input_path: Path) -> tuple[list[AnnotationTaskFeature], list[AnnotationFeature]]:
    """Process a single GeoJSON file or a directory of ``.geojson`` files.

    Each file becomes one task; annotations from all files are concatenated.

    Args:
        input_path: A ``.geojson`` file or a directory containing them.

    Returns:
        Combined lists of (task features, annotation features).

    Raises:
        ValueError: If the path does not exist, or a directory contains no
            ``.geojson`` files.
    """
    all_task_features: list[AnnotationTaskFeature] = []
    all_annotation_features: list[AnnotationFeature] = []

    if input_path.is_file():
        geojson_files = [input_path]
    elif input_path.is_dir():
        # Sorted for deterministic task ordering across runs.
        geojson_files = sorted(input_path.glob("*.geojson"))
        if not geojson_files:
            raise ValueError(f"No .geojson files found in directory: {input_path}")
    else:
        raise ValueError(f"Input path does not exist: {input_path}")

    for geojson_path in geojson_files:
        print(f"Processing: {geojson_path.name}")
        task_feature, annotation_features = dense_polygons_to_annotation_features(geojson_path)
        print("Features created!")
        all_task_features.append(task_feature)
        all_annotation_features.extend(annotation_features)

    print(f"Processed {len(geojson_files)} file(s): {len(all_task_features)} tasks, {len(all_annotation_features)} annotations")
    return all_task_features, all_annotation_features


def write_feature_collections(
    task_features: list[AnnotationTaskFeature],
    annotation_features: list[AnnotationFeature],
    output_dir: Path,
) -> None:
    """Write task and annotation features as GeoJSON FeatureCollection files.

    Produces ``ANNOTATION_FEATURES_FILE_NAME`` and
    ``ANNOTATION_TASK_FEATURES_FILE_NAME`` inside ``output_dir`` (created if
    needed).

    Args:
        task_features: One task feature per processed input file.
        annotation_features: All annotation features across inputs.
        output_dir: Directory to write the two GeoJSON files into.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Annotation features: serialized via the geojson_pydantic model so the
    # output matches what olmoearth_run parses back.
    annotation_feature_collection = FeatureCollection(type="FeatureCollection", features=annotation_features)
    annotation_output_path = output_dir / ANNOTATION_FEATURES_FILE_NAME
    with open(annotation_output_path, "w") as f:
        json.dump(annotation_feature_collection.model_dump(mode="json"), f)

    # Task features: assembled as a plain FeatureCollection dict.
    task_feature_collection = {
        "type": "FeatureCollection",
        "features": [tf.model_dump(mode="json") for tf in task_features],
    }
    task_output_path = output_dir / ANNOTATION_TASK_FEATURES_FILE_NAME
    with open(task_output_path, "w") as f:
        json.dump(task_feature_collection, f)

    print(f"Wrote {ANNOTATION_FEATURES_FILE_NAME} and {ANNOTATION_TASK_FEATURES_FILE_NAME} to {output_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert GeoJSON(s) of labeled polygons into the two files that `olmoearth_run` expects for fine-tuning"
    )
    parser.add_argument("input", type=Path, help="Input GeoJSON file or directory containing .geojson files")
    parser.add_argument("output_dir", type=Path, help="The output directory for the annotation features")
    args = parser.parse_args()

    task_features, annotation_features = process_input(args.input)
    write_feature_collections(task_features, annotation_features, args.output_dir)