diff --git a/src/pystac/asset.py b/src/pystac/asset.py index c6c6f63d5..fd4557422 100644 --- a/src/pystac/asset.py +++ b/src/pystac/asset.py @@ -6,7 +6,7 @@ from typing_extensions import deprecated from .band import Band -from .common_metadata import DataValue, Instrument +from .common_metadata import DataValue, DateTime, Instrument from .media_type import MediaType from .utils import ( get_absolute_href, @@ -21,7 +21,7 @@ from .stac_object import STACObject -class ItemAsset(DataValue, Instrument): +class ItemAsset(DateTime, DataValue, Instrument): def __init__( self, *, @@ -41,6 +41,7 @@ def __init__( else: self.bands = None self.extra_fields: dict[str, Any] = kwargs + self.extra_fields.update(DateTime.from_str(kwargs)) @override def __eq__(self, other: Any) -> bool: @@ -71,6 +72,7 @@ def clone[T: ItemAsset](self: T) -> T: def to_dict(self) -> dict[str, Any]: data: dict[str, Any] = copy.deepcopy(self.extra_fields) + data.update(DateTime.to_str(self.extra_fields)) data["title"] = self.title data["description"] = self.description data["type"] = self.type diff --git a/src/pystac/collection.py b/src/pystac/collection.py index df5578ebb..34b2177ee 100644 --- a/src/pystac/collection.py +++ b/src/pystac/collection.py @@ -2,7 +2,17 @@ import copy import datetime as dt -from typing import Any, ClassVar, TypedDict, cast, override +import warnings +from collections.abc import Iterable +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Self, + TypedDict, + cast, + override, +) from typing_extensions import deprecated @@ -12,7 +22,11 @@ from .container import Container from .link import Link from .provider import Provider -from .utils import datetime_to_str +from .utils import to_datetime_str + +if TYPE_CHECKING: + from .item import Item + from .item_collection import ItemCollection SpatialExtentBboxType = list[float | int] | list[list[float | int]] TemporalExtentIntervalType = ( @@ -64,13 +78,22 @@ def __init__( self.bands: list[Band] | None = [Band.try_from(band) for band in bands] else: self.bands = None + + if extent is None: + warnings.warn( + "Collection is missing extent, setting default spatial and " + "temporal extents" + ) self.extent: Extent = Extent.try_from(extent) self.summaries: dict[str, Any] | None = summaries - self.assets: dict[str, Asset] = ( - {key: Asset.try_from(value) for (key, value) in assets.items()} - if assets is not None - else {} - ) + + self.assets: dict[str, Asset] = {} + if assets is not None: + for key, value in assets.items(): + asset = Asset.try_from(value) + asset.set_owner(self) + self.assets[key] = asset + self.item_assets: dict[str, ItemAsset] | None = ( {key: ItemAsset.try_from(value) for (key, value) in item_assets.items()} if item_assets is not None @@ -79,16 +102,77 @@ def __init__( @override @classmethod - def from_dict( - cls, + def from_dict[T: Collection]( + cls: type[T], data: dict[str, Any], preserve_dict: bool = True, migrate: bool | None = None, root: Container | None = None, - ) -> Collection: + ) -> T: if preserve_dict: data = copy.deepcopy(data) - return Collection(**data) + return cls(**data) + + @classmethod + def from_items[T: Collection]( + cls: type[T], + items: Iterable[Item] | ItemCollection, + *, + id: str | None = None, + ) -> T: + """Create a :class:`Collection` from iterable of items or an + :class:`~pystac.ItemCollection`. + + Will try to pull collection attributes from + :attr:`~pystac.ItemCollection.extra_fields` and items when possible. + + Args: + items : Iterable of :class:`~pystac.Item` instances to include in the + :class:`Collection`. This can be a :class:`~pystac.ItemCollection`. + id : Identifier for the collection. If not set, must be available on the + items and they must all match. + """ + from .item_collection import ItemCollection + + if id is None: + values = {item.collection_id for item in items} + if len(values) == 1: + id = next(iter(values)) + if id is None: + raise ValueError( + "Collection id must be defined. Either by specifying collection_id " + "on every item, or as a keyword argument to this function." + ) + + kwargs: dict[str, Any] = {} + if isinstance(items, ItemCollection): + kwargs = copy.deepcopy(items.extra_fields) + + if "description" not in kwargs: + values = {item.properties.description for item in items} + if len(values) == 1: + kwargs["description"] = next(iter(values)) + + if "title" not in kwargs: + values = {item.properties.title for item in items} + if len(values) == 1: + kwargs["title"] = next(iter(values)) + + collection = cls( + id=id, + description=kwargs.pop("description"), + extent=Extent.from_items(items), + **kwargs, + ) + + _ = collection.add_items(items) + + return collection + + @property + @deprecated("use .type instead") + def STAC_OBJECT_TYPE(self) -> STAC_OBJECT_TYPE: + return self.type @override def set_self_href(self, href: str | None) -> None: @@ -127,15 +211,23 @@ def to_dict( } return data + def update_extent_from_items(self) -> None: + """ + Update datetime and bbox based on all items to a single bbox and time window. + """ + self.extent = Extent.from_items(self.get_items(recursive=True)) + class Extent: def __init__( self, spatial: SpatialExtent | SpatialExtentDict | None = None, temporal: TemporalExtent | TemporalExtentDict | None = None, + **kwargs: Any, ): self.spatial: SpatialExtent = SpatialExtent.try_from(spatial) self.temporal: TemporalExtent = TemporalExtent.try_from(temporal) + self.extra_fields: dict[str, Any] = kwargs @classmethod def try_from(cls, extent: Extent | dict[str, Any] | None) -> Extent: @@ -147,21 +239,120 @@ def try_from(cls, extent: Extent | dict[str, Any] | None) -> Extent: return Extent() @classmethod - def from_dict(cls, data: dict[str, Any]) -> Extent: - return Extent(**data) + def from_dict[T: Extent](cls: type[T], data: dict[str, Any]) -> T: + return cls(**data) + + @classmethod + def from_items(cls: type[Self], items: Iterable[Item]) -> Extent: + """Create an Extent based on the datetimes and bboxes of a list of items. + + Args: + items : A list of items to derive the extent from. + + Returns: + Extent: An Extent that spatially and temporally covers all of the + given items. + """ + bounds_values: list[list[float]] = [ + [float("inf")], + [float("inf")], + [float("-inf")], + [float("-inf")], + ] + datetimes: list[dt.datetime] = [] + starts: list[dt.datetime] = [] + ends: list[dt.datetime] = [] + + for item in items: + if item.bbox is not None: + for i in range(0, 4): + bounds_values[i].append(item.bbox[i]) + if item.datetime is not None: + datetimes.append(item.datetime) + if item.properties.start_datetime is not None: + starts.append(item.properties.start_datetime) + if item.properties.end_datetime is not None: + ends.append(item.properties.end_datetime) + + if not any(datetimes + starts): + start_timestamp = None + else: + start_timestamp = min( + [ + d if d.tzinfo else d.replace(tzinfo=dt.UTC) + for d in datetimes + starts + ] + ) + if not any(datetimes + ends): + end_timestamp = None + else: + end_timestamp = max( + [d if d.tzinfo else d.replace(tzinfo=dt.UTC) for d in datetimes + ends] + ) + + spatial = SpatialExtent( + [ + [ + min(bounds_values[0]), + min(bounds_values[1]), + max(bounds_values[2]), + max(bounds_values[3]), + ] + ] + ) + temporal = TemporalExtent( + [ + [ + to_datetime_str(start_timestamp), + to_datetime_str(end_timestamp), + ] + ] + ) + + return Extent(spatial=spatial, temporal=temporal) + + def clone[T: Extent](self: T) -> T: + return self.from_dict(self.to_dict()) def to_dict(self) -> dict[str, Any]: + data = copy.deepcopy(self.extra_fields) return { "spatial": self.spatial.to_dict(), "temporal": self.temporal.to_dict(), + **data, } class TemporalExtent: - def __init__(self, interval: list[list[str | None]] | None = None) -> None: - self.interval: list[list[str | None]] = interval or [ - [datetime_to_str(dt.datetime.now()), None] - ] + def __init__( + self, + interval: TemporalExtentIntervalType | str | dt.datetime | None = None, + **kwargs: Any, + ) -> None: + if interval is None and "intervals" in kwargs: + warnings.warn( + "intervals is deprecated and will be removed in a future version. " + "Use interval instead", + FutureWarning, + ) + interval = kwargs.pop("intervals") + + self.interval: list[list[str | None]] + if interval is None: + self.interval = [[None, None]] + elif isinstance(interval, (str, dt.datetime)): + self.interval = [to_interval([interval, None])] + elif isinstance(interval[0], list): + self.interval = [ + to_interval(cast(list[dt.datetime | str | None], dates)) + for dates in interval + ] + else: + self.interval = [ + to_interval(cast(list[dt.datetime | str | None], interval)) + ] + + self.extra_fields: dict[str, Any] = kwargs @classmethod def try_from( @@ -172,30 +363,44 @@ def try_from( return data elif not data: return TemporalExtent() - elif isinstance(data["interval"][0], list): - return TemporalExtent( - [ - to_interval(cast(list[dt.datetime | str | None], interval)) - for interval in data["interval"] - ] - ) else: - return TemporalExtent( - [to_interval(cast(list[dt.datetime | str | None], data["interval"]))] - ) + return TemporalExtent(**data) + + @property + @deprecated("Use interval instead") + def intervals(self) -> list[list[str | None]]: + return self.interval @classmethod - @deprecated("Use default initializer instead") + @deprecated("Use `TemporalExtent(dt.datetime.now())` instead") def from_now(cls) -> TemporalExtent: - return TemporalExtent() + return TemporalExtent(interval=dt.datetime.now()) def to_dict(self) -> dict[str, Any]: - return {"interval": self.interval} + data = copy.deepcopy(self.extra_fields) + return {"interval": self.interval, **data} class SpatialExtent: - def __init__(self, bbox: list[list[float | int]] | None = None) -> None: - self.bbox: list[list[float | int]] = bbox or [[-180, -90, 180, 90]] + def __init__( + self, bbox: SpatialExtentBboxType | None = None, **kwargs: Any + ) -> None: + if bbox is None and "bboxes" in kwargs: + warnings.warn( + "bboxes is deprecated and will be removed in a future version. " + "Use bbox instead", + FutureWarning, + ) + bbox = kwargs.pop("bboxes") + + self.bbox: list[list[float | int]] + if bbox is None: + self.bbox = [[-180, -90, 180, 90]] + elif isinstance(bbox[0], list): + self.bbox = cast(list[list[float | int]], bbox) + else: + self.bbox = [cast(list[float | int], bbox)] + self.extra_fields: dict[str, Any] = kwargs @classmethod def try_from( @@ -206,30 +411,20 @@ def try_from( return data elif not data: return SpatialExtent() - elif isinstance(data["bbox"][0], list): - return SpatialExtent(cast(list[list[float | int]], data["bbox"])) else: - return SpatialExtent([cast(list[float | int], data["bbox"])]) + return SpatialExtent(**data) - @classmethod - @deprecated("Use try_from instead") - def from_coordinates(cls, coordinates: Any) -> SpatialExtent: - return SpatialExtent.try_from({"bbox": coordinates}) + @property + @deprecated("Use bbox instead") + def bboxes(self) -> list[list[float | int]]: + return self.bbox def to_dict(self) -> dict[str, Any]: - return {"bbox": self.bbox} + data = copy.deepcopy(self.extra_fields) + return {"bbox": self.bbox, **data} def to_interval(interval: list[dt.datetime | str | None]) -> list[str | None]: if len(interval) != 2: raise ValueError("Interval must have exactly two elements") return [to_datetime_str(interval[0]), to_datetime_str(interval[1])] - - -def to_datetime_str(datetime: dt.datetime | str | None) -> str | None: - if datetime is None: - return None - elif isinstance(datetime, str): - return datetime - else: - return datetime_to_str(datetime) diff --git a/src/pystac/common_metadata.py b/src/pystac/common_metadata.py index 134b36d96..a8bc40afd 100644 --- a/src/pystac/common_metadata.py +++ b/src/pystac/common_metadata.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Protocol, override from .provider import Provider -from .utils import datetime_to_str, str_to_datetime +from .utils import from_datetime_str, to_datetime_str if TYPE_CHECKING: from .asset import Asset @@ -71,38 +71,20 @@ class DateTime(Protocol): @property def datetime(self) -> dt.datetime | None: """See the Item Specification Fields for more information.""" - value: str | None = self.extra_fields.get("datetime") - datetime: dt.datetime | None = None - if value is not None: - datetime = str_to_datetime(value) - return datetime + return self.extra_fields.get("datetime") @datetime.setter def datetime(self, value: str | dt.datetime | None) -> None: # pyright: ignore[reportPropertyTypeMismatch] - datetime: str | None = None - if isinstance(value, str): - datetime = value - elif isinstance(value, dt.datetime): - datetime = datetime_to_str(value) - self.extra_fields["datetime"] = datetime + self.extra_fields["datetime"] = from_datetime_str(value) @property def created(self) -> dt.datetime | None: """Creation date and time of the corresponding STAC entity or Asset, in UTC.""" - value: str | None = self.extra_fields.get("created") - datetime: dt.datetime | None = None - if value is not None: - datetime = str_to_datetime(value) - return datetime + return self.extra_fields.get("created") @created.setter def created(self, value: str | dt.datetime | None) -> None: # pyright: ignore[reportPropertyTypeMismatch] - datetime: str | None = None - if isinstance(value, str): - datetime = value - elif isinstance(value, dt.datetime): - datetime = datetime_to_str(value) - self.extra_fields["created"] = datetime + self.extra_fields["created"] = from_datetime_str(value) @property def updated(self) -> dt.datetime | None: @@ -110,56 +92,59 @@ def updated(self) -> dt.datetime | None: Date and time the corresponding STAC entity or Asset was updated last, in UTC. """ - value: str | None = self.extra_fields.get("updated") - datetime: dt.datetime | None = None - if value is not None: - datetime = str_to_datetime(value) - return datetime + return self.extra_fields.get("updated") @updated.setter def updated(self, value: str | dt.datetime | None) -> None: # pyright: ignore[reportPropertyTypeMismatch] - datetime: str | None = None - if isinstance(value, str): - datetime = value - elif isinstance(value, dt.datetime): - datetime = datetime_to_str(value) - self.extra_fields["updated"] = datetime + self.extra_fields["updated"] = from_datetime_str(value) @property def start_datetime(self) -> dt.datetime | None: """The first or start date and time for the resource, in UTC.""" - value: str | None = self.extra_fields.get("start_datetime") - datetime: dt.datetime | None = None - if value is not None: - datetime = str_to_datetime(value) - return datetime + return self.extra_fields.get("start_datetime") @start_datetime.setter def start_datetime(self, value: str | dt.datetime | None) -> None: # pyright: ignore[reportPropertyTypeMismatch] - datetime: str | None = None - if isinstance(value, str): - datetime = value - elif isinstance(value, dt.datetime): - datetime = datetime_to_str(value) - self.extra_fields["start_datetime"] = datetime + self.extra_fields["start_datetime"] = from_datetime_str(value) @property def end_datetime(self) -> dt.datetime | None: """The last or end date and time for the resource, in UTC.""" - value: str | None = self.extra_fields.get("end_datetime") - datetime: dt.datetime | None = None - if value is not None: - datetime = str_to_datetime(value) - return datetime + return self.extra_fields.get("end_datetime") @end_datetime.setter def end_datetime(self, value: str | dt.datetime | None) -> None: # pyright: ignore[reportPropertyTypeMismatch] - datetime: str | None = None - if isinstance(value, str): - datetime = value - elif isinstance(value, dt.datetime): - datetime = datetime_to_str(value) - self.extra_fields["end_datetime"] = datetime + self.extra_fields["end_datetime"] = from_datetime_str(value) + + @staticmethod + def from_str(extra_fields: dict[str, Any]) -> dict[str, dt.datetime | None]: + """Create new dict with iso strings replaced by dt.datetimes""" + data: dict[str, dt.datetime | None] = {} + for field in ( + "datetime", + "created", + "updated", + "start_datetime", + "end_datetime", + ): + if field in extra_fields: + data[field] = from_datetime_str(extra_fields[field]) + return data + + @staticmethod + def to_str(extra_fields: dict[str, Any]) -> dict[str, str | None]: + """Create new dict with dt.datetimes replaced by iso strings""" + data: dict[str, str | None] = {} + for field in ( + "datetime", + "created", + "updated", + "start_datetime", + "end_datetime", + ): + if field in extra_fields: + data[field] = to_datetime_str(extra_fields[field]) + return data class Licensing(Protocol): @@ -387,6 +372,7 @@ class CommonMetadata(Basics, DateTime, Licensing, Providers, Instrument, DataVal def __init__(self, /, object: Asset | Item): self.object = object + self.extra_fields.update(DateTime.from_str(self.extra_fields)) @property @override diff --git a/src/pystac/container.py b/src/pystac/container.py index 8d1ac12e7..674643a33 100644 --- a/src/pystac/container.py +++ b/src/pystac/container.py @@ -4,7 +4,7 @@ import os.path import warnings from abc import ABC -from collections.abc import Callable, Iterator +from collections.abc import Callable, Iterable, Iterator from typing import TYPE_CHECKING, Any from typing_extensions import deprecated @@ -20,6 +20,7 @@ if TYPE_CHECKING: from .catalog import CatalogType # pyright: ignore[reportDeprecated] + from .collection import Collection from .href_generator import HrefGenerator from .item import Item from .stac_io import StacIO @@ -27,6 +28,9 @@ class Container(STACObject, ABC): + def get_item(self, id: str, recursive: bool = False) -> Item | None: + return next(self.get_items(id, recursive=recursive), None) + def get_items(self, *ids: str, recursive: bool = False) -> Iterator[Item]: for link in self.links: if link.is_item(): @@ -61,7 +65,7 @@ def add_item( return item_link - def add_items(self, items: list[Item]) -> list[Link]: + def add_items(self, items: Iterable[Item]) -> list[Link]: links: list[Link] = [] for item in items: links.append(self.add_item(item)) @@ -145,19 +149,54 @@ def remove_item(self, id: str) -> Item | None: def clear_items(self) -> None: self.links: list[Link] = [link for link in self.links if not link.is_item()] - def get_child(self, id: str, recursive: bool = False) -> Container | None: - for child in self.get_children(recursive=recursive): - if child.id == id: - return child + def get_collections( + self, *ids: str, recursive: bool = False, sort_links_by_id: bool = True + ) -> Iterator[Collection]: + from .collection import Collection - def get_children(self, recursive: bool = False) -> Iterator[Container]: - for link in self.links: - if link.is_child(): - stac_object = link.get_target(start_href=self._href, reader=self.reader) - if isinstance(stac_object, Container): + for child in self.get_children( + *ids, recursive=recursive, sort_links_by_id=sort_links_by_id + ): + if isinstance(child, Collection) and (not ids or child.id in ids): + yield child + + @deprecated( + "Get all collections is deprecated, use `get_collections(recursive=True)`" + ) + def get_all_collections(self) -> Iterator[Collection]: + yield from self.get_collections(recursive=True) + + def get_child( + self, id: str, recursive: bool = False, sort_links_by_id: bool = True + ) -> Container | None: + return next( + self.get_children( + id, recursive=recursive, sort_links_by_id=sort_links_by_id + ), + None, + ) + + def get_children( + self, *ids: str, recursive: bool = False, sort_links_by_id: bool = True + ) -> Iterator[Container]: + links = self.get_child_links() + if ids and sort_links_by_id: + links = sorted( + links, + key=lambda x: (href := x.get_href()) is None + or all(id not in href for id in ids), + ) + + for link in links: + stac_object = link.get_target(start_href=self._href, reader=self.reader) + + if isinstance(stac_object, Container): + if not ids or stac_object.id in ids: yield stac_object - if recursive: - yield from stac_object.get_children(recursive=recursive) + if recursive: + yield from stac_object.get_children( + *ids, recursive=recursive, sort_links_by_id=sort_links_by_id + ) def get_child_links(self) -> list[Link]: return [link for link in self.links if link.is_child()] diff --git a/src/pystac/item.py b/src/pystac/item.py index 53db877b7..107556048 100644 --- a/src/pystac/item.py +++ b/src/pystac/item.py @@ -24,7 +24,7 @@ from .geo_interface import GeoInterface from .link import Link from .stac_object import STACObject -from .utils import datetime_to_str, make_posix_style, str_to_datetime +from .utils import make_posix_style if TYPE_CHECKING: from .collection import Collection @@ -120,15 +120,15 @@ def datetime(self) -> dt.datetime | None: @deprecated("Get the datetime from the asset directly") def get_datetime(self, asset: Asset | None = None) -> dt.datetime | None: - if asset and (datetime := asset.extra_fields.get("datetime")): - return str_to_datetime(datetime) + if asset and (datetime := asset.datetime): + return datetime else: return self.properties.datetime @deprecated("Set the datetime on the asset directly") def set_datetime(self, datetime: dt.datetime, asset: Asset | None = None) -> None: if asset: - asset.extra_fields["datetime"] = datetime_to_str(datetime) + asset.datetime = datetime else: self.properties.datetime = datetime @@ -239,16 +239,15 @@ class Properties(Basics, DateTime, Licensing, Providers, Instrument): def __init__( self, *, - datetime: dt.datetime | str | None = None, - start_datetime: dt.datetime | str | None = None, - end_datetime: dt.datetime | str | None = None, bands: list[Band] | list[dict[str, Any]] | None = None, **kwargs: Any, ): self.extra_fields: dict[str, Any] = kwargs - self.datetime = datetime or dt.datetime.now(tz=dt.UTC) - self.start_datetime = start_datetime - self.end_datetime = end_datetime + self.extra_fields.update(DateTime.from_str(kwargs)) + + if not any([self.datetime, self.start_datetime, self.end_datetime]): + self.datetime = dt.datetime.now(tz=dt.UTC) + if bands is not None: self.bands: list[Band] | None = [Band.try_from(band) for band in bands] else: @@ -269,6 +268,12 @@ def __deepcopy__(self, memo: dict[int, Any]) -> Properties: **copy.deepcopy(self.extra_fields, memo), ) + def get(self, key: str, default: Any = None) -> Any: + try: + return self[key] + except KeyError: + return default + @classmethod def try_from( cls, @@ -287,6 +292,8 @@ def from_dict(cls, data: dict[str, Any]) -> Properties: def to_dict(self) -> dict[str, Any]: data = copy.deepcopy(self.extra_fields) + data.update(DateTime.to_str(self.extra_fields)) + if self.bands is not None: data["bands"] = [band.to_dict() for band in self.bands] - return {k: v for k, v in data.items() if v is not None} + return {k: v for k, v in data.items() if v is not None or k == "datetime"} diff --git a/src/pystac/item_collection.py b/src/pystac/item_collection.py index d922052a6..f47e5c10e 100644 --- a/src/pystac/item_collection.py +++ b/src/pystac/item_collection.py @@ -1,12 +1,24 @@ +from __future__ import annotations + +import copy from collections.abc import Iterator -from typing import override +from pathlib import Path +from typing import Any, Literal, TypedDict, override from .item import Item +from .reader import DEFAULT_READER, Reader +from .utils import make_absolute_href + + +class T_ItemCollection(TypedDict): + type: Literal["FeatureCollection"] + features: list[dict[str, Any]] class ItemCollection: - def __init__(self, items: list[Item]): + def __init__(self, items: list[Item], **kwargs: Any): self.items: list[Item] = items + self.extra_fields: dict[str, Any] = kwargs def __len__(self) -> int: return len(self.items) @@ -20,3 +32,41 @@ def __iter__(self) -> Iterator[Item]: @override def __repr__(self) -> str: return f"ItemCollection({self.items})" + + def to_dict(self) -> dict[str, Any]: + data = copy.deepcopy(self.extra_fields) + return { + "type": "FeatureCollection", + "features": [item.to_dict() for item in self.items], + **data, + } + + @classmethod + def try_from(cls, data: dict[str, Any] | ItemCollection) -> ItemCollection: + if isinstance(data, ItemCollection): + return data + else: + return cls(**data) + + @classmethod + def from_dict( + cls, + data: dict[str, Any], + preserve_dict: bool = True, + ) -> ItemCollection: + if preserve_dict: + data = copy.deepcopy(data) + + items = data.get("features", []) + extra_fields = {k: v for k, v in data.items() if k not in ("features", "type")} + + return cls(items=[Item.from_dict(item) for item in items], **extra_fields) + + @classmethod + def from_file( + cls, path: str | Path, reader: Reader = DEFAULT_READER + ) -> ItemCollection: + href = make_absolute_href(str(path)) + data = reader.get_json(href) + + return ItemCollection.from_dict(data) diff --git a/src/pystac/link.py b/src/pystac/link.py index 23eabcb4f..352151726 100644 --- a/src/pystac/link.py +++ b/src/pystac/link.py @@ -123,6 +123,11 @@ def get_href(self) -> str | None: def set_href(self, href: str) -> None: self._href = href + @property + @deprecated("href is deprecated, use .get_href()") + def href(self) -> str | None: + return self.get_href() + @property @deprecated("target is deprecated, either use .get_href() or .get_target()") def target(self) -> str | STACObject | None: diff --git a/src/pystac/utils.py b/src/pystac/utils.py index e2f9e6f9e..198d6edb0 100644 --- a/src/pystac/utils.py +++ b/src/pystac/utils.py @@ -464,6 +464,24 @@ def str_to_datetime(s: str) -> datetime: return dateutil.parser.isoparse(s) +def to_datetime_str(value: datetime | str | None) -> str | None: + if value is None: + return None + elif isinstance(value, str): + return value + else: + return datetime_to_str(value) + + +def from_datetime_str(value: datetime | str | None) -> datetime | None: + if value is None: + return None + elif isinstance(value, datetime): + return value + else: + return str_to_datetime(value) + + def now_in_utc() -> datetime: """Returns a datetime value of now with the UTC timezone applied""" return datetime.now(UTC) diff --git a/tests/v1/test_collection.py b/tests/v1/test_collection.py index 28ad7bb2d..de0473a6b 100644 --- a/tests/v1/test_collection.py +++ b/tests/v1/test_collection.py @@ -30,6 +30,10 @@ from .utils import ARBITRARY_BBOX, ARBITRARY_GEOM, TestCases + +pytestmark = pytest.mark.passing_v2 + + TEST_DATETIME = datetime(2020, 3, 14, 16, 32) @@ -62,6 +66,7 @@ def test_provider_to_from_dict() -> None: ) +@pytest.mark.xfail(reason="SpatialExtent.from_coordinates is not supported in pystac v2") def test_spatial_extent_from_coordinates() -> None: extent = SpatialExtent.from_coordinates(ARBITRARY_GEOM["coordinates"]) @@ -79,6 +84,7 @@ def test_read_eo_items_are_heritable() -> None: assert EOExtension.has_extension(item) +@pytest.mark.xfail(reason="Collection does not have a catalog type in pystac v2") def test_save_uses_previous_catalog_type() -> None: collection = TestCases.case_8() assert collection.STAC_OBJECT_TYPE == pystac.STACObjectType.COLLECTION @@ -92,6 +98,7 @@ def test_save_uses_previous_catalog_type() -> None: assert collection2.catalog_type == CatalogType.SELF_CONTAINED +@pytest.mark.xfail(reason="Collection does not have a catalog type in pystac v2") def test_clone_uses_previous_catalog_type() -> None: catalog = TestCases.case_8() assert catalog.catalog_type == CatalogType.SELF_CONTAINED @@ -207,12 +214,13 @@ def test_update_extents() -> None: assert [ [ - item2.common_metadata.start_datetime, + datetime_to_str(item2.common_metadata.start_datetime), base_extent.temporal.intervals[0][1], ] ] == collection.extent.temporal.intervals +@pytest.mark.xfail(reason="Supplying href in init is not supported in pystac v2") def test_supplying_href_in_init_does_not_fail() -> None: test_href = "http://example.com/collection.json" spatial_extent = SpatialExtent(bboxes=[ARBITRARY_BBOX]) @@ -226,6 +234,7 @@ def test_supplying_href_in_init_does_not_fail() -> None: assert collection.get_self_href() == test_href +@pytest.mark.xfail(reason="Caching is not implemented in pystac v2") def test_collection_with_href_caches_by_href() -> None: collection = pystac.Collection.from_file( TestCases.get_path("data-files/examples/hand-0.8.1/collection.json") @@ -270,6 +279,7 @@ def test_get_assets() -> None: assert no_assets == {} +@pytest.mark.xfail(reason="Pystac v2 doesn't support summaries yet") def test_removing_optional_attributes() -> None: path = TestCases.get_path("data-files/collections/with-assets.json") with open(path) as file: @@ -299,7 +309,6 @@ def test_removing_optional_attributes() -> None: collection_as_dict = collection.to_dict() for key in ( "title", - "stac_extensions", "keywords", "providers", "summaries", @@ -318,12 +327,20 @@ def test_from_dict_preserves_dict() -> None: _ = Collection.from_dict(param_dict) assert param_dict == collection_dict + +@pytest.mark.xfail(reason="Dict is never mutated in pystac v2") +def test_from_dict_with_preserve_dict_False_does_not_preserve_dict() -> None: + path = TestCases.get_path("data-files/collections/with-assets.json") + with open(path) as f: + collection_dict = json.load(f) + param_dict = deepcopy(collection_dict) # assert that the parameter is not preserved with # non-default parameter _ = Collection.from_dict(param_dict, preserve_dict=False, migrate=False) assert param_dict != collection_dict +@pytest.mark.xfail(reason="Collection.from_dict does not set self href in pystac v2") def test_from_dict_set_root() -> None: path = TestCases.get_path("data-files/examples/hand-0.8.1/collection.json") with open(path) as f: @@ -333,6 +350,7 @@ def test_from_dict_set_root() -> None: assert collection.get_root() is catalog +@pytest.mark.xfail(reason="Pystac v2 doesn't support summaries yet") def test_schema_summary() -> None: collection = pystac.Collection.from_file( TestCases.get_path( @@ -395,8 +413,9 @@ def test_temporal_extent_init_typing() -> None: _ = TemporalExtent([[start_datetime, end_datetime]]) +@pytest.mark.xfail(reason="pystac v2 uses strings to represent interval") @pytest.mark.block_network() -def test_temporal_extent_allows_single_interval() -> None: +def test_temporal_extent_represents_interval_as_datetimes() -> None: start_datetime = str_to_datetime("2022-01-01T00:00:00Z") end_datetime = str_to_datetime("2022-01-31T23:59:59Z") @@ -406,14 +425,28 @@ def test_temporal_extent_allows_single_interval() -> None: assert temporal_extent.intervals == [interval] +@pytest.mark.block_network() +def test_temporal_extent_allows_single_interval() -> None: + start = "2022-01-01T00:00:00Z" + end = "2022-01-31T23:59:59Z" + start_datetime = str_to_datetime(start) + end_datetime = str_to_datetime(end) + + interval = [start_datetime, end_datetime] + temporal_extent = TemporalExtent(intervals=interval) + + assert temporal_extent.intervals == [[start, end]] + + @pytest.mark.block_network() def test_temporal_extent_allows_single_interval_open_start() -> None: - end_datetime = str_to_datetime("2022-01-31T23:59:59Z") + end = "2022-01-31T23:59:59Z" + end_datetime = str_to_datetime(end) interval = [None, end_datetime] temporal_extent = TemporalExtent(intervals=interval) - assert temporal_extent.intervals == [interval] + assert temporal_extent.intervals == [[None, end]] @pytest.mark.block_network() @@ -494,8 +527,8 @@ def test_extent_from_items() -> None: assert len(extent.temporal.intervals) == 1 interval = extent.temporal.intervals[0] - assert interval[0] == datetime(2000, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC) - assert interval[1] == datetime(2001, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC) + assert interval[0] == datetime_to_str(datetime(2000, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC)) + assert interval[1] == datetime_to_str(datetime(2001, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC)) def test_extent_to_from_dict() -> None: @@ -569,14 +602,14 @@ def test_clone(self) -> None: assert isinstance(cloned_collection, self.BasicCustomCollection) + @pytest.mark.xfail(reason="Pystac v2 requires that custom classes match method signatures") def test_collection_get_item_works(self) -> None: path = TestCases.get_path( "data-files/catalogs/test-case-1/country-1/area-1-1/collection.json" ) custom_collection = self.BasicCustomCollection.from_file(path) collection = custom_collection.clone() - with pytest.warns(DeprecationWarning): - collection.get_item("area-1-1-imagery") + collection.get_item("area-1-1-imagery") def test_collection_get_item_raises_type_error() -> None: @@ -711,8 +744,7 @@ def test_permissive_temporal_extent_deserialization(collection: Collection) -> N collection_dict["extent"]["temporal"]["interval"] = collection_dict["extent"][ "temporal" ]["interval"][0] - with pytest.warns(UserWarning): - Collection.from_dict(collection_dict) + Collection.from_dict(collection_dict) @pytest.mark.parametrize("fixture_name", ("sample_item_collection", "sample_items")) @@ -729,8 +761,8 @@ def test_from_items(fixture_name: str, request: pytest.FixtureRequest) -> None: start = collection.extent.temporal.intervals[0][0] end = collection.extent.temporal.intervals[0][1] - assert start and start <= str_to_datetime(item.properties["start_datetime"]) - assert end and end >= str_to_datetime(item.properties["end_datetime"]) + assert start and str_to_datetime(start) <= item.properties["start_datetime"] + assert end and str_to_datetime(end) >= item.properties["end_datetime"] if isinstance(items, ItemCollection): expected = {(link["rel"], link["href"]) for link in items.extra_fields["links"]} @@ -831,7 +863,7 @@ def test_from_dict_null_extent(collection: Collection) -> None: with pytest.warns(UserWarning): c = Collection.from_dict(d) - assert c.extent.spatial.to_dict()["bbox"] == [[-90, -180, 90, 180]] + assert c.extent.spatial.to_dict()["bbox"] == [[-180, -90, 180, 90]] assert c.extent.temporal.to_dict()["interval"] == [[None, None]] @@ -841,5 +873,5 @@ def test_from_dict_missing_extent(collection: Collection) -> None: with pytest.warns(UserWarning): c = Collection.from_dict(d) - assert c.extent.spatial.to_dict()["bbox"] == [[-90, -180, 90, 180]] + assert c.extent.spatial.to_dict()["bbox"] == [[-180, -90, 180, 90]] assert c.extent.temporal.to_dict()["interval"] == [[None, None]] diff --git a/tests/v1/test_common_metadata.py b/tests/v1/test_common_metadata.py index c2c0edfb5..74bb54f70 100644 --- a/tests/v1/test_common_metadata.py +++ b/tests/v1/test_common_metadata.py @@ -33,8 +33,8 @@ def test_datetimes(date_time_range_item: Item) -> None: # save dict of original item to check that `common_metadata` # method doesn't mutate self.item_1 before = date_time_range_item.clone().to_dict() - start_datetime_str = date_time_range_item.properties["start_datetime"] - assert isinstance(start_datetime_str, str) + start_datetime = date_time_range_item.properties["start_datetime"] + assert isinstance(start_datetime, datetime) cm = date_time_range_item.common_metadata assert isinstance(cm, CommonMetadata) @@ -43,6 +43,7 @@ def test_datetimes(date_time_range_item: Item) -> None: assert cm.providers is None +@pytest.mark.xfail(reason="pystac v2 purposefully converts datetimes to datetime objects") def test_common_metadata_start_datetime(date_time_range_item: Item) -> None: x = date_time_range_item.clone() start_datetime_str = "2018-01-01T13:21:30Z" @@ -59,6 +60,7 @@ def test_common_metadata_start_datetime(date_time_range_item: Item) -> None: assert x.properties["start_datetime"] == example_datetime_str +@pytest.mark.xfail(reason="pystac v2 purposefully converts datetimes to datetime objects") def test_common_metadata_end_datetime(date_time_range_item: Item) -> None: x = date_time_range_item.clone() end_datetime_str = "2018-01-01T13:31:30Z" @@ -75,6 +77,7 @@ def test_common_metadata_end_datetime(date_time_range_item: Item) -> None: assert x.properties["end_datetime"] == example_datetime_str +@pytest.mark.xfail(reason="pystac v2 purposefully converts datetimes to datetime objects") def test_common_metadata_created(sample_full_item: Item) -> None: x = sample_full_item.clone() created_str = "2016-05-04T00:00:01Z" @@ -91,6 +94,7 @@ def test_common_metadata_created(sample_full_item: Item) -> None: assert x.properties["created"] == example_datetime_str +@pytest.mark.xfail(reason="pystac v2 purposefully converts datetimes to datetime objects") def test_common_metadata_updated(sample_full_item: Item) -> None: x = sample_full_item.clone() updated_str = "2017-01-01T00:30:55Z" diff --git a/tests/v1/utils/test_cases.py b/tests/v1/utils/test_cases.py index 19324673b..f2bc08ad2 100644 --- a/tests/v1/utils/test_cases.py +++ b/tests/v1/utils/test_cases.py @@ -63,7 +63,7 @@ ] ARBITRARY_EXTENT = Extent( - spatial=SpatialExtent.from_coordinates(ARBITRARY_GEOM["coordinates"]), + spatial=SpatialExtent(ARBITRARY_BBOX), temporal=TemporalExtent.from_now(), )