Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define XarraySource.from_stac() for more convenient creation of an XarraySource from a STAC Item or ItemCollection #2061

Merged
merged 2 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from rastervision.core.data.utils import parse_array_slices_Nd, fill_overflow

if TYPE_CHECKING:
from pystac import Item, ItemCollection
from rastervision.core.data import RasterTransformer, CRSTransformer

log = logging.getLogger(__name__)
Expand All @@ -26,7 +27,6 @@ def __init__(self,
crs_transformer: 'CRSTransformer',
raster_transformers: List['RasterTransformer'] = [],
channel_order: Optional[Sequence[int]] = None,
num_channels_raw: Optional[int] = None,
bbox: Optional[Box] = None,
temporal: bool = False):
"""Constructor.
Expand Down Expand Up @@ -63,8 +63,7 @@ def __init__(self,
self.ndim = data_array.ndim
self._crs_transformer = crs_transformer

if num_channels_raw is None:
num_channels_raw = len(data_array.band)
num_channels_raw = len(data_array.band)
if channel_order is None:
channel_order = np.arange(num_channels_raw, dtype=int)
else:
Expand Down Expand Up @@ -93,6 +92,78 @@ def __init__(self,
raster_transformers=raster_transformers,
bbox=bbox)

@classmethod
def from_stac(
cls,
item_or_item_collection: Union['Item', 'ItemCollection'],
raster_transformers: List['RasterTransformer'] = [],
channel_order: Optional[Sequence[int]] = None,
bbox: Optional[Box] = None,
bbox_map_coords: Optional[Box] = None,
temporal: bool = False,
allow_streaming: bool = False,
stackstac_args: dict = dict(rescale=False)) -> 'XarraySource':
"""Construct an ``XarraySource`` from a STAC Item or ItemCollection.

Args:
item_or_item_collection: STAC Item or ItemCollection.
raster_transformers: RasterTransformers to use to transform chips
after they are read.
channel_order: List of indices of channels to extract from raw
imagery. Can be a subset of the available channels. If None,
all channels available in the image will be read.
Defaults to None.
bbox: User-specified crop of the extent. If None, the full extent
available in the source file is used. Mutually exclusive with
``bbox_map_coords``. Defaults to ``None``.
bbox_map_coords: User-specified bbox in EPSG:4326 coords of the
form (ymin, xmin, ymax, xmax). Useful for cropping the raster
source so that only part of the raster is read from. Mutually
exclusive with ``bbox``. Defaults to ``None``.
temporal: If True, data_array is expected to have a "time"
dimension and the chips returned will be of shape (T, H, W, C).
allow_streaming: If False, load the entire DataArray into memory.
Defaults to True.
stackstac_args: Optional arguments to pass to stackstac.stack().
"""
import stackstac

data_array = stackstac.stack(item_or_item_collection, **stackstac_args)

if not temporal and 'time' in data_array.dims:
if len(data_array.time) > 1:
raise ValueError('temporal=False but len(data_array.time) > 1')
data_array = data_array.isel(time=0)

if not allow_streaming:
from humanize import naturalsize
log.info('Loading the full DataArray into memory '
f'({naturalsize(data_array.nbytes)}).')
data_array.load()

crs_transformer = RasterioCRSTransformer(
transform=data_array.transform, image_crs=data_array.crs)

if bbox is not None:
if bbox_map_coords is not None:
raise ValueError('Specify either bbox or bbox_map_coords, '
'but not both.')
bbox = Box(*bbox)
elif bbox_map_coords is not None:
bbox_map_coords = Box(*bbox_map_coords)
bbox = crs_transformer.map_to_pixel(bbox_map_coords).normalize()
else:
bbox = None

raster_source = XarraySource(
data_array,
crs_transformer=crs_transformer,
raster_transformers=raster_transformers,
channel_order=channel_order,
bbox=bbox,
temporal=temporal)
return raster_source

@property
def shape(self) -> Tuple[int, int, int]:
"""Shape of the raster as a (height, width, num_channels) tuple."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
import logging

from rastervision.pipeline.config import Field, register_config
from rastervision.core.box import Box
from rastervision.core.data.raster_source.raster_source_config import (
RasterSourceConfig)
from rastervision.core.data.crs_transformer import RasterioCRSTransformer
from rastervision.core.data.raster_source.stac_config import (
STACItemConfig, STACItemCollectionConfig)
from rastervision.core.data.raster_source.xarray_source import (XarraySource)
from rastervision.core.data.raster_source.xarray_source import XarraySource

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,43 +36,17 @@ class XarraySourceConfig(RasterSourceConfig):
def build(self,
tmp_dir: Optional[str] = None,
use_transformers: bool = True) -> XarraySource:
import stackstac

item_or_item_collection = self.stac.build()
data_array = stackstac.stack(item_or_item_collection,
**self.stackstac_args)

if not self.temporal and 'time' in data_array.dims:
if len(data_array.time) > 1:
raise ValueError('temporal=False but len(data_array.time) > 1')
data_array = data_array.isel(time=0)

if not self.allow_streaming:
from humanize import naturalsize
log.info('Loading the full DataArray into memory '
f'({naturalsize(data_array.nbytes)}).')
data_array.load()

crs_transformer = RasterioCRSTransformer(
transform=data_array.transform, image_crs=data_array.crs)
raster_transformers = ([rt.build() for rt in self.transformers]
if use_transformers else [])

if self.bbox is not None:
if self.bbox_map_coords is not None:
log.info('Using bbox and ignoring bbox_map_coords.')
bbox = Box(*self.bbox)
elif self.bbox_map_coords is not None:
bbox_map_coords = Box(*self.bbox_map_coords)
bbox = crs_transformer.map_to_pixel(bbox_map_coords).normalize()
else:
bbox = None

raster_source = XarraySource(
data_array,
crs_transformer=crs_transformer,
raster_source = XarraySource.from_stac(
item_or_item_collection,
raster_transformers=raster_transformers,
channel_order=self.channel_order,
bbox=bbox,
temporal=self.temporal)
bbox=self.bbox,
bbox_map_coords=self.bbox_map_coords,
temporal=self.temporal,
allow_streaming=self.allow_streaming,
stackstac_args=self.stackstac_args,
)
return raster_source