diff --git a/autoware_ml/common/enums/enums.py b/autoware_ml/common/enums/enums.py index 16147c5..6d06f0c 100644 --- a/autoware_ml/common/enums/enums.py +++ b/autoware_ml/common/enums/enums.py @@ -15,6 +15,21 @@ from enum import Enum +class Modality(str, Enum): + """ + Modality. + + Attributes: + LIDAR: Lidar modality. + CAMERA: Camera modality. + RADAR: Radar modality. + """ + + LIDAR = "lidar" + CAMERA = "camera" + RADAR = "radar" + + class SplitType(str, Enum): """ Split type. diff --git a/autoware_ml/configs/database/t4dataset/t4dataset_base.yaml b/autoware_ml/configs/database/t4dataset/t4dataset_base.yaml index e1af26a..70042f1 100644 --- a/autoware_ml/configs/database/t4dataset/t4dataset_base.yaml +++ b/autoware_ml/configs/database/t4dataset/t4dataset_base.yaml @@ -18,5 +18,8 @@ defaults: - t4dataset/scenarios@scenarios.db_jpntaxi_base: detection3d/db_jpntaxi_base - t4dataset/scenarios@scenarios.db_largebus: detection3d/db_largebus +# Number of features in the lidar pointcloud +lidar_pointcloud_num_features: 5 + # Processor settings num_workers: 16 diff --git a/autoware_ml/configs/database/t4dataset/t4dataset_j6gen2_base.yaml b/autoware_ml/configs/database/t4dataset/t4dataset_j6gen2_base.yaml index 38b7f67..cb3ab31 100644 --- a/autoware_ml/configs/database/t4dataset/t4dataset_j6gen2_base.yaml +++ b/autoware_ml/configs/database/t4dataset/t4dataset_j6gen2_base.yaml @@ -16,5 +16,8 @@ defaults: - t4dataset/scenarios@scenarios.db_j6gen2: detection3d/db_j6gen2 - t4dataset/scenarios@scenarios.db_largebus: detection3d/db_largebus +# Number of features in the lidar pointcloud +lidar_pointcloud_num_features: 5 + # Processor settings num_workers: 16 diff --git a/autoware_ml/configs/database/t4dataset/t4dataset_jpntaxi_base.yaml b/autoware_ml/configs/database/t4dataset/t4dataset_jpntaxi_base.yaml index 333e7e2..940e275 100644 --- a/autoware_ml/configs/database/t4dataset/t4dataset_jpntaxi_base.yaml +++ b/autoware_ml/configs/database/t4dataset/t4dataset_jpntaxi_base.yaml 
@@ -14,5 +14,8 @@ cache_file_prefix_name: database defaults: - t4dataset/scenarios@scenarios.db_jpntaxi_base: detection3d/db_jpntaxi_base +# Number of features in the lidar pointcloud +lidar_pointcloud_num_features: 5 + # Processor settings num_workers: 16 diff --git a/autoware_ml/databases/base_database.py b/autoware_ml/databases/base_database.py index c772786..9645176 100644 --- a/autoware_ml/databases/base_database.py +++ b/autoware_ml/databases/base_database.py @@ -22,7 +22,7 @@ import polars as pl from autoware_ml.databases.scenarios import Scenarios, ScenarioData -from autoware_ml.databases.schemas import DatasetRecord, DatasetTableSchema +from autoware_ml.databases.schemas.dataset_schemas import DatasetRecord, DatasetTableSchema logger = logging.getLogger(__name__) diff --git a/autoware_ml/databases/database_interface.py b/autoware_ml/databases/database_interface.py index 746bab4..22fdc78 100644 --- a/autoware_ml/databases/database_interface.py +++ b/autoware_ml/databases/database_interface.py @@ -19,7 +19,7 @@ from types import MappingProxyType from autoware_ml.databases.scenarios import Scenarios, ScenarioData -from autoware_ml.databases.schemas import DatasetRecord +from autoware_ml.databases.schemas.dataset_schemas import DatasetRecord class DatabaseInterface(Protocol): diff --git a/autoware_ml/databases/schemas/__init__.py b/autoware_ml/databases/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/autoware_ml/databases/schemas/base_schemas.py b/autoware_ml/databases/schemas/base_schemas.py new file mode 100644 index 0000000..7d24660 --- /dev/null +++ b/autoware_ml/databases/schemas/base_schemas.py @@ -0,0 +1,87 @@ +# Copyright 2026 TIER IV, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import NamedTuple, Sequence, Mapping, Any + +import polars as pl + + +class DatasetTableColumn(NamedTuple): + """ + Annotation table column. + + Attributes: + name: Name of the column. + dtype: Data type of the column. + """ + + name: str + dtype: pl.DataType + + +class BaseFieldSchema: + """ + Base class for field schemas. + """ + + @classmethod + def to_polars_field_schema(cls) -> Sequence[pl.Field]: + """ + Convert the lidar column schema to a Polars field schema. + + Returns: + pl.Schema: Polars schema. + """ + + return [ + pl.Field(v.name, v.dtype) + for k, v in cls.__dict__.items() + if not k.startswith("__") and isinstance(v, DatasetTableColumn) + ] + + +class DataModelInterface(ABC): + """ + Interface for data models. + """ + + @abstractmethod + def to_dictionary(self) -> Mapping[str, Any]: + """ + Convert the data model to a dictionary. + + Returns: + Mapping[str, Any]: Dictionary representation of the data model. + """ + + raise NotImplementedError("Subclasses must implement to_dictionary!") + + @classmethod + def load_from_dictionary(cls, data_model: Mapping[str, Any]) -> DataModelInterface: + """ + Load the data model and decode it to the corresponding data model from a dictionary, which is + deserialized from a Polars dataframe. + + Args: + data_model: Dictionary representation of the data model, which is + deserialized from a Polars dataframe. + + Returns: + DataModelInterface: Data model. 
+ """ + + raise NotImplementedError("Subclasses must implement load_from_dictionary!") diff --git a/autoware_ml/databases/schemas/category_mapping.py b/autoware_ml/databases/schemas/category_mapping.py new file mode 100644 index 0000000..c45afc2 --- /dev/null +++ b/autoware_ml/databases/schemas/category_mapping.py @@ -0,0 +1,85 @@ +# Copyright 2026 TIER IV, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence, Any, Mapping + +import polars as pl +from pydantic import BaseModel, ConfigDict + +from autoware_ml.databases.schemas.base_schemas import ( + BaseFieldSchema, + DatasetTableColumn, + DataModelInterface, +) + + +@dataclass(frozen=True) +class CategoryMappingDatasetSchema(BaseFieldSchema): + """ + Dataclass to define polars schema for columns related to category mapping. + """ + + CATEGORY_NAMES = DatasetTableColumn("category_names", pl.List(pl.String)) + CATEGORY_INDICES = DatasetTableColumn("category_indices", pl.List(pl.Int32)) + + +class CategoryMappingDataModel(BaseModel, DataModelInterface): + """ + Category mapping data model that can be shared by multiple datasets. It saves the mapping + between category names and category indices. + + Attributes: + category_names: List of category names. + category_indices: List of category indices. 
+ """ + + model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True) + + category_names: Sequence[str] + category_indices: Sequence[int] + + def model_post_init(self, __context: Any) -> None: + """Validate that all attributes are of the same length.""" + + assert len(self.category_names) == len(self.category_indices), ( + "All attributes must be of the same length" + ) + + def to_dictionary(self) -> Mapping[str, Any]: + """ + Convert the category mapping data model to a dictionary. + + Returns: + Mapping[str, Any]: Dictionary representation of the category mapping data model. + """ + + return self.model_dump() + + @classmethod + def load_from_dictionary(cls, data_model: Mapping[str, Any]) -> CategoryMappingDataModel: + """ + Load the category mapping data model and decode it to the corresponding CategoryMappingDataModel + from a dictionary, which is deserialized from a Polars dataframe. + + Args: + data_model: Dictionary representation of the category mapping data model, which is + deserialized from a Polars dataframe. + """ + return cls( + category_names=data_model[CategoryMappingDatasetSchema.CATEGORY_NAMES.name], + category_indices=data_model[CategoryMappingDatasetSchema.CATEGORY_INDICES.name], + ) diff --git a/autoware_ml/databases/schemas/dataset_schemas.py b/autoware_ml/databases/schemas/dataset_schemas.py new file mode 100644 index 0000000..d120ca9 --- /dev/null +++ b/autoware_ml/databases/schemas/dataset_schemas.py @@ -0,0 +1,221 @@ +# Copyright 2026 TIER IV, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence, Mapping, Any + +import polars as pl +from pydantic import BaseModel, ConfigDict + +from autoware_ml.databases.schemas.base_schemas import DatasetTableColumn, DataModelInterface +from autoware_ml.databases.schemas.lidar_frames import LidarFrameDatasetSchema, LidarFrameDataModel +from autoware_ml.databases.schemas.category_mapping import ( + CategoryMappingDataModel, + CategoryMappingDatasetSchema, +) +from autoware_ml.databases.schemas.lidar_sources import ( + LidarSourceDatasetSchema, + LidarSourceDataModel, +) + + +@dataclass(frozen=True) +class DatasetTableSchema: + """ + Annotation table schema. + + Attributes: + SCENARIO_ID: Scenario ID column. + SAMPLE_ID: Sample ID column. + SAMPLE_INDEX: Sample index column. + TIMESTAMP_SECONDS: Timestamp in seconds column. + LOCATION: Location column. + VEHICLE_TYPE: Vehicle type column. + SCENARIO_NAME: Scenario name column. + + # LiDAR Schema + LIDAR_FRAMES: Lidar frames column, which is a list of dictionaries to save metadata of a lidar + frame. It also saves lidar sweeps as each item here. + + # Lidar Sources Schema + LIDAR_SOURCES: Lidar sources column, which is a list of dictionaries to save metadata about + each lidar sensor. + + # Category Schema + CATEGORY_MAPPING: Category mapping column, which is a dictionary to save the mapping between + category names and category indices. 
+ """ + + # Basic Schema + SCENARIO_ID = DatasetTableColumn("scenario_id", pl.String) + SAMPLE_ID = DatasetTableColumn("sample_id", pl.String) + SAMPLE_INDEX = DatasetTableColumn("sample_index", pl.Int32) + TIMESTAMP_SECONDS = DatasetTableColumn("timestamp_seconds", pl.Float64) + LOCATION = DatasetTableColumn("location", pl.String) + VEHICLE_TYPE = DatasetTableColumn("vehicle_type", pl.String) + SCENARIO_NAME = DatasetTableColumn("scenario_name", pl.String) + + # LiDAR Frames Schema + LIDAR_FRAMES = DatasetTableColumn( + "lidar_frames", pl.List(pl.Struct(LidarFrameDatasetSchema.to_polars_field_schema())) + ) + + # LiDAR Sources Schema + LIDAR_SOURCES = DatasetTableColumn( + "lidar_sources", pl.List(pl.Struct(LidarSourceDatasetSchema.to_polars_field_schema())) + ) + + # Category Schema + CATEGORY_MAPPING = DatasetTableColumn( + "category_mapping", + pl.Struct(CategoryMappingDatasetSchema.to_polars_field_schema()), + ) + + @classmethod + def to_polars_schema(cls) -> pl.Schema: + """ + Convert the dataset table schema to a Polars schema. + + Returns: + pl.Schema: Polars schema. + """ + + return pl.Schema( + { + v.name: v.dtype + for k, v in cls.__dict__.items() + if not k.startswith("__") and isinstance(v, DatasetTableColumn) + } + ) + + +class DatasetRecord(BaseModel, DataModelInterface): + """ + Data class to save a record for each column in the annotation table. + + Attributes: + # Basic Metadata + scenario_id: Scenario ID. + sample_id: Sample ID. + sample_index: Sample index. + location: Location of the vehicle. + vehicle_type: Type of the vehicle. + + # LiDAR frame data + lidar_frames: List of lidar frame data models, including multisweep lidar frames. + + # Lidar sources data + lidar_sources: List of lidar source data models. + + # Category data + category_mapping: Category mapping data model. 
+ """ + + # Set model config to frozen + model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True) + + # Basic Dataset Record + scenario_id: str + sample_id: str + sample_index: int + timestamp_seconds: float + location: str + vehicle_type: str + scenario_name: str + + lidar_frames: Sequence[LidarFrameDataModel] + lidar_sources: Sequence[LidarSourceDataModel] | None + category_mapping: CategoryMappingDataModel | None + + def to_dictionary(self) -> Mapping[str, Any]: + """ + Convert the dataset record to a dictionary. + + Returns: + Mapping[str, Any]: Dictionary representation of the dataset record. + """ + data_model = { + DatasetTableSchema.SCENARIO_ID.name: self.scenario_id, + DatasetTableSchema.SAMPLE_ID.name: self.sample_id, + DatasetTableSchema.SAMPLE_INDEX.name: self.sample_index, + DatasetTableSchema.TIMESTAMP_SECONDS.name: self.timestamp_seconds, + DatasetTableSchema.LOCATION.name: self.location, + DatasetTableSchema.VEHICLE_TYPE.name: self.vehicle_type, + DatasetTableSchema.SCENARIO_NAME.name: self.scenario_name, + } + data_model[DatasetTableSchema.LIDAR_FRAMES.name] = [ + lidar_frame.to_dictionary() for lidar_frame in self.lidar_frames + ] + + if self.lidar_sources: + data_model[DatasetTableSchema.LIDAR_SOURCES.name] = [ + lidar_source.to_dictionary() for lidar_source in self.lidar_sources + ] + else: + data_model[DatasetTableSchema.LIDAR_SOURCES.name] = None + + if self.category_mapping: + data_model[DatasetTableSchema.CATEGORY_MAPPING.name] = ( + self.category_mapping.to_dictionary() + ) + else: + data_model[DatasetTableSchema.CATEGORY_MAPPING.name] = None + + return data_model + + @classmethod + def load_from_dictionary(cls, data_model: Mapping[str, Any]) -> DatasetRecord: + """ + Load the dataset record from a Polars dataframe. + + Args: + data_model: Dictionary representation of the dataset record, which is + deserialized from a Polars dataframe. + + Returns: + DatasetRecord: Data model of the dataset record. 
+ """ + lidar_frames = data_model[DatasetTableSchema.LIDAR_FRAMES.name] + lidar_frames = [ + LidarFrameDataModel.load_from_dictionary(lidar_frame) for lidar_frame in lidar_frames + ] + + lidar_sources = data_model[DatasetTableSchema.LIDAR_SOURCES.name] + if lidar_sources is not None: + lidar_sources = [ + LidarSourceDataModel.load_from_dictionary(lidar_source) + for lidar_source in lidar_sources + ] + else: + lidar_sources = None + + category_mapping = data_model[DatasetTableSchema.CATEGORY_MAPPING.name] + if category_mapping is not None: + category_mapping = CategoryMappingDataModel.load_from_dictionary(category_mapping) + else: + category_mapping = None + + return cls( + scenario_id=data_model[DatasetTableSchema.SCENARIO_ID.name], + sample_id=data_model[DatasetTableSchema.SAMPLE_ID.name], + sample_index=data_model[DatasetTableSchema.SAMPLE_INDEX.name], + timestamp_seconds=data_model[DatasetTableSchema.TIMESTAMP_SECONDS.name], + location=data_model[DatasetTableSchema.LOCATION.name], + vehicle_type=data_model[DatasetTableSchema.VEHICLE_TYPE.name], + scenario_name=data_model[DatasetTableSchema.SCENARIO_NAME.name], + lidar_frames=lidar_frames, + lidar_sources=lidar_sources, + category_mapping=category_mapping, + ) diff --git a/autoware_ml/databases/schemas/frame_basic_metadata.py b/autoware_ml/databases/schemas/frame_basic_metadata.py new file mode 100644 index 0000000..941c0e5 --- /dev/null +++ b/autoware_ml/databases/schemas/frame_basic_metadata.py @@ -0,0 +1,40 @@ +# Copyright 2026 TIER IV, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from pydantic import BaseModel, ConfigDict + + +class FrameBasicMetadata(BaseModel): + """ + Basic metadata for a frame/record that can be shared by multiple datasets. + + Attributes: + scenario_id: Scenario ID. + sample_id: Sample ID. + sample_index: Sample index. + timestamp_seconds: Timestamp in seconds. + scenario_name: Scenario name. + location: Location. + vehicle_type: Vehicle type. + """ + + model_config = ConfigDict(frozen=True, strict=True) + + scenario_id: str + sample_id: str + sample_index: int + timestamp_seconds: float + scenario_name: str + location: str | None + vehicle_type: str | None diff --git a/autoware_ml/databases/schemas/lidar_frames.py b/autoware_ml/databases/schemas/lidar_frames.py new file mode 100644 index 0000000..733fc70 --- /dev/null +++ b/autoware_ml/databases/schemas/lidar_frames.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Mapping, Any + +import numpy as np +import numpy.typing as npt +import polars as pl +from pydantic import BaseModel, ConfigDict + +from autoware_ml.databases.schemas.base_schemas import ( + BaseFieldSchema, + DatasetTableColumn, + DataModelInterface, +) + + +@dataclass(frozen=True) +class LidarFrameDatasetSchema(BaseFieldSchema): + """ + Dataclass to define polars schema for columns related to lidar. 
+ """ + + lidar_frame_id = DatasetTableColumn("lidar_frame_id", pl.String) + lidar_keyframe = DatasetTableColumn("lidar_keyframe", pl.Boolean) + lidar_sensor_id = DatasetTableColumn("lidar_sensor_id", pl.String) + lidar_timestamp_seconds = DatasetTableColumn("lidar_timestamp_seconds", pl.Float64) + lidar_sensor_channel_name = DatasetTableColumn("lidar_sensor_channel_name", pl.String) + lidar_pointcloud_path = DatasetTableColumn("lidar_pointcloud_path", pl.String) + lidar_pointcloud_source_path = DatasetTableColumn("lidar_pointcloud_source_path", pl.String) + lidar_pointcloud_num_features = DatasetTableColumn("lidar_pointcloud_num_features", pl.Int32) + lidar_sensor_to_ego_pose_matrix = DatasetTableColumn( + "lidar_sensor_to_ego_pose_matrix", pl.Array(pl.Float32, shape=(4, 4)) + ) + lidar_frame_ego_pose_to_global_matrix = DatasetTableColumn( + "lidar_frame_ego_pose_to_global_matrix", pl.Array(pl.Float32, shape=(4, 4)) + ) + lidar_sensor_to_lidar_sweep_matrices = DatasetTableColumn( + "lidar_sensor_to_lidar_sweep_matrices", pl.Array(pl.Float32, shape=(4, 4)) + ) + lidar_pointcloud_semantic_mask_path = DatasetTableColumn( + "lidar_pointcloud_semantic_mask_path", pl.String + ) + + +class LidarFrameDataModel(BaseModel, DataModelInterface): + """ + Lidar frame data model that can be shared by multiple datasets. It saves the metadata of a lidar + frame. Note that lidar sweeps also use this data model. + + Attributes: + lidar_frame_id: Lidar frame ID. + lidar_keyframe: Whether this lidar frame is a keyframe. Set to True if it's a keyframe, + otherwise, it is a sweep frame. + lidar_sensor_id: Lidar sensor ID. + lidar_sensor_channel_name: Lidar sensor channel name. + lidar_timestamp_seconds: Lidar timestamp in seconds. + lidar_pointcloud_path: Lidar pointcloud path. + lidar_pointcloud_source_path: Lidar pointcloud source path, which is the path to the + information for each lidar pointcloud. Set to None if it's not available. 
+ lidar_pointcloud_num_features: Lidar pointcloud num features. + lidar_sensor_to_ego_pose_matrix: Transformation matrix from the lidar sensor of this frame to + the ego pose of this lidar frame. + lidar_frame_ego_pose_to_global_matrix: Transformation matrix from the ego pose of this lidar + frame to the global frame. + lidar_sensor_to_lidar_sweep_matrices: Transformation matrices from the main lidar sensor + to other lidar sweeps at this frame. + lidar_pointcloud_semantic_mask_path: Lidar pointcloud semantic mask path. Set to None if it's + not available. + """ + + model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True) + + lidar_frame_id: str + lidar_keyframe: bool + lidar_sensor_id: str + lidar_sensor_channel_name: str + lidar_timestamp_seconds: float + lidar_pointcloud_path: str + lidar_pointcloud_source_path: str | None + lidar_pointcloud_num_features: int + lidar_sensor_to_ego_pose_matrix: npt.NDArray[np.float64] # (4, 4) + # Transformation matrix from the ego pose of this lidar frame to the global frame. + lidar_frame_ego_pose_to_global_matrix: npt.NDArray[np.float64] # (4, 4) + # Transformation matrices from the main lidar sensor to other lidar sweeps at this frame. + lidar_sensor_to_lidar_sweep_matrices: npt.NDArray[np.float64] # (4, 4) + lidar_pointcloud_semantic_mask_path: str | None + + @property + def lidar_pointcloud_relative_path(self) -> str: + """ + Parse lidar pointcloud path to {database_version}/{scene_id}/ + {dataset_version}/data/{lidar_token}/{frame}.bin from path. + + Returns: + str: Lidar pointcloud relative path. + """ + + return "/".join(self.lidar_pointcloud_path.split("/")[-6:]) + + @property + def lidar_pointcloud_source_relative_path(self) -> str | None: + """ + Parse lidar pointcloud source path to {database_version}/{scene_id}/ + {dataset_version}/data/{lidar_token}/{frame}.bin from path. + + Returns: + str | None: Lidar pointcloud source relative path. 
+ """ + if self.lidar_pointcloud_source_path is None: + return None + + return "/".join(self.lidar_pointcloud_source_path.split("/")[-6:]) + + @property + def lidarseg_pointcloud_semantic_mask_relative_path(self) -> str | None: + """ + Parse lidarseg pts semantic mask path to {database_version}/{scene_id}/ + {dataset_version}/data/{lidar_token}/{frame}.bin from path. + """ + if self.lidar_pointcloud_semantic_mask_path is None: + return None + + return "/".join(self.lidar_pointcloud_semantic_mask_path.split("/")[-6:]) + + @property + def lidar_sensor_to_ego_pose_matrix_fp32(self) -> npt.NDArray[np.float32]: + """ + Convert the lidar sensor to ego pose matrix to float32. + + Returns: + npt.NDArray[np.float32]: Lidar sensor to ego pose matrix. + """ + + return self.lidar_sensor_to_ego_pose_matrix.astype(np.float32) + + @property + def lidar_frame_ego_pose_to_global_matrix_fp32(self) -> npt.NDArray[np.float32]: + """ + Convert the lidar frame ego pose to global matrix to float32. + + Returns: + npt.NDArray[np.float32]: Lidar frame ego pose to global matrix. + """ + + return self.lidar_frame_ego_pose_to_global_matrix.astype(np.float32) + + @property + def lidar_sensor_to_lidar_sweep_matrices_fp32(self) -> npt.NDArray[np.float32]: + """ + Convert the lidar sensor to lidar sweep matrices to float32. + + Returns: + npt.NDArray[np.float32]: Lidar sensor to lidar sweep matrices. + """ + + return self.lidar_sensor_to_lidar_sweep_matrices.astype(np.float32) + + def to_dictionary(self) -> Mapping[str, Any]: + """ + Convert the lidar frame data model to a dictionary. + + Returns: + Mapping[str, Any]: Dictionary representation of the lidar frame data model. 
+ """ + + return { + LidarFrameDatasetSchema.lidar_frame_id.name: self.lidar_frame_id, + LidarFrameDatasetSchema.lidar_keyframe.name: self.lidar_keyframe, + LidarFrameDatasetSchema.lidar_sensor_id.name: self.lidar_sensor_id, + LidarFrameDatasetSchema.lidar_timestamp_seconds.name: self.lidar_timestamp_seconds, + LidarFrameDatasetSchema.lidar_sensor_channel_name.name: self.lidar_sensor_channel_name, + LidarFrameDatasetSchema.lidar_pointcloud_path.name: self.lidar_pointcloud_path, + LidarFrameDatasetSchema.lidar_pointcloud_source_path.name: self.lidar_pointcloud_source_path, + LidarFrameDatasetSchema.lidar_pointcloud_num_features.name: self.lidar_pointcloud_num_features, + LidarFrameDatasetSchema.lidar_sensor_to_ego_pose_matrix.name: self.lidar_sensor_to_ego_pose_matrix_fp32, + LidarFrameDatasetSchema.lidar_frame_ego_pose_to_global_matrix.name: self.lidar_frame_ego_pose_to_global_matrix_fp32, + LidarFrameDatasetSchema.lidar_sensor_to_lidar_sweep_matrices.name: self.lidar_sensor_to_lidar_sweep_matrices_fp32, + LidarFrameDatasetSchema.lidar_pointcloud_semantic_mask_path.name: self.lidar_pointcloud_semantic_mask_path, + } + + @classmethod + def load_from_dictionary(cls, data_model: Mapping[str, Any]) -> LidarFrameDataModel: + """ + Load the lidar frame data model and decode it to the corresponding LidarFrameDataModel + from a dictionary, which is deserialized from a Polars dataframe. + + Args: + data_model: Dictionary representation of the lidar frame data model, which is + deserialized from a Polars dataframe. + + Returns: + LidarFrameDataModel: LidarFrameDataModel object. 
+ """ + + return cls( + lidar_frame_id=data_model[LidarFrameDatasetSchema.lidar_frame_id.name], + lidar_keyframe=data_model[LidarFrameDatasetSchema.lidar_keyframe.name], + lidar_sensor_id=data_model[LidarFrameDatasetSchema.lidar_sensor_id.name], + lidar_timestamp_seconds=data_model[ + LidarFrameDatasetSchema.lidar_timestamp_seconds.name + ], + lidar_sensor_channel_name=data_model[ + LidarFrameDatasetSchema.lidar_sensor_channel_name.name + ], + lidar_pointcloud_path=data_model[LidarFrameDatasetSchema.lidar_pointcloud_path.name], + lidar_pointcloud_source_path=data_model[ + LidarFrameDatasetSchema.lidar_pointcloud_source_path.name + ], + lidar_pointcloud_num_features=data_model[ + LidarFrameDatasetSchema.lidar_pointcloud_num_features.name + ], + lidar_sensor_to_ego_pose_matrix=data_model[ + LidarFrameDatasetSchema.lidar_sensor_to_ego_pose_matrix.name + ], + lidar_frame_ego_pose_to_global_matrix=data_model[ + LidarFrameDatasetSchema.lidar_frame_ego_pose_to_global_matrix.name + ], + lidar_sensor_to_lidar_sweep_matrices=data_model[ + LidarFrameDatasetSchema.lidar_sensor_to_lidar_sweep_matrices.name + ], + lidar_pointcloud_semantic_mask_path=data_model[ + LidarFrameDatasetSchema.lidar_pointcloud_semantic_mask_path.name + ], + ) diff --git a/autoware_ml/databases/schemas/lidar_sources.py b/autoware_ml/databases/schemas/lidar_sources.py new file mode 100644 index 0000000..3c4ba27 --- /dev/null +++ b/autoware_ml/databases/schemas/lidar_sources.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Mapping, Any + +import numpy as np +import numpy.typing as npt +import polars as pl +from pydantic import BaseModel, ConfigDict + +from autoware_ml.databases.schemas.base_schemas import ( + BaseFieldSchema, + DatasetTableColumn, + DataModelInterface, +) + + +@dataclass(frozen=True) +class LidarSourceDatasetSchema(BaseFieldSchema): + """ + Dataclass to define polars schema for columns related to lidar pointcloud. 
+ """ + + channel_name = DatasetTableColumn("channel_name", pl.String) + sensor_token = DatasetTableColumn("sensor_token", pl.String) + translation = DatasetTableColumn("translation", pl.Array(pl.Float32, shape=(3,))) + rotation = DatasetTableColumn("rotation", pl.Array(pl.Float32, shape=(4,))) + + +class LidarSourceDataModel(BaseModel, DataModelInterface): + """ + Lidar source data model that can be shared by multiple datasets. + + Attributes: + channel_name: Lidar source channel name. + sensor_token: Lidar source sensor token. + translation: Lidar source translation (3, ). + rotation: Lidar source rotation (4, ). + """ + + model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True) + + channel_name: str + sensor_token: str + translation: npt.NDArray[np.float64] + rotation: npt.NDArray[np.float64] + + @property + def translation_fp32(self) -> npt.NDArray[np.float32]: + """ + Convert the lidar source translations to float32. + + Returns: + npt.NDArray[np.float32]: Lidar source translation. + """ + + return self.translation.astype(np.float32) + + @property + def rotation_fp32(self) -> npt.NDArray[np.float32]: + """ + Convert the lidar source rotations to float32. + + Returns: + npt.NDArray[np.float32]: Lidar source rotation. + """ + + return self.rotation.astype(np.float32) + + def to_dictionary(self) -> Mapping[str, Any]: + """ + Convert the lidar source data model to a dictionary. + + Returns: + Mapping[str, Any]: Dictionary representation of the lidar source data model. 
+ """ + + return { + LidarSourceDatasetSchema.channel_name.name: self.channel_name, + LidarSourceDatasetSchema.sensor_token.name: self.sensor_token, + LidarSourceDatasetSchema.translation.name: self.translation_fp32, + LidarSourceDatasetSchema.rotation.name: self.rotation_fp32, + } + + @classmethod + def load_from_dictionary(cls, data_model: Mapping[str, Any]) -> LidarSourceDataModel: + """ + Load the lidar source data model and decode it to the corresponding LidarSourceDataModel + from a dictionary, which is deserialized from a Polars dataframe. + + Args: + data_model: Dictionary representation of the lidar source data model, which is + deserialized from a Polars dataframe. + + Returns: + LidarSourceDataModel: LidarSourceDataModel object. + """ + + return cls( + channel_name=data_model[LidarSourceDatasetSchema.channel_name.name], + sensor_token=data_model[LidarSourceDatasetSchema.sensor_token.name], + translation=data_model[LidarSourceDatasetSchema.translation.name], + rotation=data_model[LidarSourceDatasetSchema.rotation.name], + ) diff --git a/autoware_ml/databases/t4dataset/t4dataset.py b/autoware_ml/databases/t4dataset/t4dataset.py index 4fe30d1..5048054 100644 --- a/autoware_ml/databases/t4dataset/t4dataset.py +++ b/autoware_ml/databases/t4dataset/t4dataset.py @@ -25,12 +25,12 @@ import polars as pl from tqdm import tqdm -from autoware_ml.databases.database_interface import DatabaseInterface from autoware_ml.databases.base_database import BaseDatabase -from autoware_ml.databases.t4dataset.t4scenarios import T4Scenarios +from autoware_ml.databases.database_interface import DatabaseInterface from autoware_ml.databases.scenarios import ScenarioData -from autoware_ml.databases.schemas import DatasetRecord +from autoware_ml.databases.schemas.dataset_schemas import DatasetRecord from autoware_ml.databases.t4dataset.t4records_generator import T4RecordsGenerator +from autoware_ml.databases.t4dataset.t4scenarios import T4Scenarios logger = logging.getLogger(__name__) 
@@ -45,10 +45,12 @@ class T4RecordsGeneratorWorkerParams: database_root_path: Root path of the T4 database. dataset_version: Version of the dataset. scenario_data: Scenario data. + lidar_pointcloud_num_features: Number of features in the lidar pointcloud. """ database_root_path: str scenario_data: ScenarioData + lidar_pointcloud_num_features: int def _apply_t4_records_generator( @@ -69,6 +71,7 @@ def _apply_t4_records_generator( scenario_data=t4_records_generator_worker_params.scenario_data, sample_steps=t4_records_generator_worker_params.scenario_data.sample_steps, max_sweeps=t4_records_generator_worker_params.scenario_data.max_sweeps, + lidar_pointcloud_num_features=t4_records_generator_worker_params.lidar_pointcloud_num_features, ) # Generate DatasetRecords return t4_records_generator.generate_dataset_records() @@ -85,6 +88,7 @@ def __init__( cache_path: str, cache_file_prefix_name: str, num_workers: int, + lidar_pointcloud_num_features: int, ) -> None: """ Initialize T4 dataset. Please refer to the BaseDatabase class for more details. @@ -96,6 +100,7 @@ def __init__( cache_path: Path to cache the dataset records. cache_file_prefix_name: Prefix name of the cache file, it will be _.parquet num_workers: Number of workers to use for processing the dataset. + lidar_pointcloud_num_features: Number of features in the lidar pointcloud. 
""" logger.info("Initializing T4 dataset...") @@ -107,6 +112,7 @@ def __init__( num_workers=num_workers, ) self._scenarios = scenarios + self._lidar_pointcloud_num_features = lidar_pointcloud_num_features def __str__(self) -> str: """ @@ -151,24 +157,26 @@ def process_scenario_records(self) -> Sequence[DatasetRecord]: # TODO (KokSeang): Read the cache if it exists, and return the records - # First, read all unique scenario data + # 1) Read all unique scenario data unique_scenario_data = self.get_unique_scenario_data() logger.info( f"Processing a total of {len(unique_scenario_data)} unique scenarios in T4Dataset" ) - # Second, send the list to the multiprocessing or single processing the scenario + # 2) Send the list to the multiprocessing or single processing the scenario # samples/frames scenario_sample_records = self._run_t4records_generator(unique_scenario_data) logger.info(f"Processed {len(scenario_sample_records)} scenario sample records") - # Third, get the polar schema + # 3) Save the scenario sample records to a polars .parquet file + # Dump to a list of dictionaries to make it safer since it's using Pydantic.BaseModel + scenario_sample_records = [record.to_dictionary() for record in scenario_sample_records] + + # 4) Get the polar schema polars_schema = self.get_polars_schema() logger.info(f"Parquet schema: {polars_schema}") - # Fourth, save the scenario sample records to a polars .parquet file - # Dump to a list of dictionaries to make it safer since it's using Pydantic.BaseModel - scenario_sample_records = [record.model_dump() for record in scenario_sample_records] + # 5) Save the scenario sample records to a polars .parquet file df = pl.DataFrame(scenario_sample_records, schema=polars_schema) df_hash = hashlib.sha256(str(self).encode("utf-8")).hexdigest() df_cache_path = self._cache_path / f"{self._cache_file_prefix_name}_{df_hash}.parquet" @@ -201,6 +209,7 @@ def _run_t4records_generator( T4RecordsGeneratorWorkerParams( 
database_root_path=self._database_root_path, scenario_data=scenario, + lidar_pointcloud_num_features=self._lidar_pointcloud_num_features, ) for scenario in scenario_data.values() ] diff --git a/autoware_ml/databases/t4dataset/t4records_generator.py b/autoware_ml/databases/t4dataset/t4records_generator.py index a11702e..bd53505 100644 --- a/autoware_ml/databases/t4dataset/t4records_generator.py +++ b/autoware_ml/databases/t4dataset/t4records_generator.py @@ -15,71 +15,54 @@ import logging from pathlib import Path -from typing import Sequence +from typing import Sequence, Tuple -from pydantic import BaseModel, ConfigDict +import numpy as np +import numpy.typing as npt from t4_devkit import Tier4 -from t4_devkit.schema import Sample, SampleData, CalibratedSensor -from t4_devkit.typing import Quaternion, Vector3 +from t4_devkit.schema import ( + CalibratedSensor, + EgoPose, + LidarSeg, + Sample, + SampleData, + Scene, + Sensor, + SchemaName, +) +from t4_devkit.common.timestamp import microseconds2seconds -from autoware_ml.common.enums.enums import LidarChannel -from autoware_ml.databases.schemas import DatasetRecord -from autoware_ml.databases.scenarios import ScenarioData - -logger = logging.getLogger(__name__) - - -class T4SampleRecord(BaseModel): - """ - Temporary T4 sample record. - Attributes: - scenario_id: Scenario ID. - sample_id: Sample ID. - sample_index: Sample index. - lidar_path: Lidar path. - location: Location. - vehicle_type: Vehicle type. 
- """ - - model_config = ConfigDict(arbitrary_types_allowed=True) - - scenario_id: str - sample_id: str - sample_index: int - lidar_path: str - location: str | None - vehicle_type: str | None - - # Lidar to ego transformation - lidar2ego_translation: Vector3 - lidar2ego_rotation: Quaternion +from autoware_ml.common.enums.enums import LidarChannel, Modality +from autoware_ml.databases.schemas.frame_basic_metadata import FrameBasicMetadata +from autoware_ml.databases.schemas.dataset_schemas import DatasetRecord +from autoware_ml.databases.schemas.lidar_frames import LidarFrameDataModel +from autoware_ml.databases.schemas.lidar_sources import LidarSourceDataModel +from autoware_ml.databases.schemas.category_mapping import CategoryMappingDataModel +from autoware_ml.databases.scenarios import ScenarioData +from autoware_ml.databases.t4dataset.t4sample_records import ( + T4SampleRecord, +) +from autoware_ml.utils.dataset import convert_quaternion_to_matrix - def to_dataset_record(self) -> DatasetRecord: - """ - Convert T4 sample record to dataset record. - Returns: - DatasetRecord: Dataset record. - """ - return DatasetRecord( - scenario_id=self.scenario_id, - sample_id=self.sample_id, - sample_index=self.sample_index, - location=self.location, - vehicle_type=self.vehicle_type, - ) +logger = logging.getLogger(__name__) class T4RecordsGenerator: """RecordsGenerator for T4Dataset.""" + __MODALITY_STRING = "modality" + __VALUE_STRING = "value" + __IGNORE_LABEL_INDEX = -1 + def __init__( self, database_root_path: str, scenario_data: ScenarioData, max_sweeps: int, sample_steps: int, + lidar_pointcloud_num_features: int, ) -> None: """ Initialize T4RecordsGenerator. @@ -91,12 +74,21 @@ def __init__( if skipping lidar sweep concatenation. sample_steps: Number of frames/samples to skip between each sample, set to 1 if not skipping any samples/frames. + lidar_pointcloud_num_features: Number of features of the lidar pointcloud. 
+ label_remapping (MappingProxyType[str, int]): Remapping of the label names to another + label name. + filter_attributes (MappingProxyType[str, Sequence[str]]): 3D bounding boxes with the + class names and selected attributes in the filter_attributes will be filtered out. + merge_objects (MappingProxyType[str, Sequence[str, str]]): Mapping of the target labels + to the source labels to merge the 3D bounding boxes. + """ self.database_root_path = Path(database_root_path) self.scenario_data = scenario_data self.max_sweeps = max_sweeps self.sample_steps = sample_steps + self.lidar_pointcloud_num_features = lidar_pointcloud_num_features self.t4_devkit_dataset = self._construct_t4_devkit_dataset() assert sample_steps > 0, "Sample steps must be greater than 0." @@ -132,14 +124,360 @@ def generate_dataset_records(self) -> Sequence[DatasetRecord]: logger.info( f"Generating dataset records for scenario: {self.scenario_data.scenario_id} with sample steps: {self.sample_steps} and max sweeps: {self.max_sweeps}" ) + for sample_index in range(0, len(self.t4_devkit_dataset.sample), self.sample_steps): sample = self.t4_devkit_dataset.sample[sample_index] t4_sample_record = self.extract_t4_sample_record(sample, sample_index) + + if t4_sample_record is None: + logger.info( + f"dataset_name: {self.scenario_data.dataset_name}, " + f"scenario_id: {self.scenario_data.scenario_id}, " + f"sample_index: {sample_index}, " + f"No lidar channel found in sample data" + ) + continue + records.append(t4_sample_record.to_dataset_record()) return records - def extract_t4_sample_record(self, sample: Sample, sample_index: int) -> T4SampleRecord: + def _extract_sample_basic_metadata( + self, sample: Sample, sample_index: int + ) -> FrameBasicMetadata: + """ + Extract basic metadata from a T4 sample. + + Args: + sample: T4 Sample. + sample_index: Sample index. + + Returns: + FrameBasicMetadata: Frame basic metadata of the T4 sample. 
+ """ + + scene_record: Scene = self.t4_devkit_dataset.get(SchemaName.SCENE, sample.scene_token) + return FrameBasicMetadata( + scenario_id=self.scenario_data.scenario_id, + sample_id=sample.token, + sample_index=sample_index, + location=self.scenario_data.location, + vehicle_type=self.scenario_data.vehicle_type, + timestamp_seconds=microseconds2seconds(sample.timestamp), + scenario_name=scene_record.name, + ) + + def _extract_lidar_pointcloud_semantic_mask_path( + self, + sample_index: int, + calibrated_lidar_sample_data_token: str, + lidar_pointcloud_source_path: str | None, + ) -> str | None: + """ + Extract lidarseg metadata from a T4 Sample. + + Args: + sample_index: Sample index. + calibrated_lidar_sample_data_token: Calibrated lidar sample data token. + lidar_pointcloud_source_path: Lidar pointcloud source path. + + Returns: + LidarSegMetaData: Lidarseg metadata of the T4 sample. + """ + lidarseg_records: Sequence[LidarSeg] = getattr( + self.t4_devkit_dataset, SchemaName.LIDARSEG, [] + ) + # If there are no lidarseg records or the lidar pointcloud source path is not available, + # return None + if not len(lidarseg_records) or not lidar_pointcloud_source_path: + return None + + assert sample_index < len(lidarseg_records), ( + "Sample index is out of range of lidarseg records." + ) + + current_lidarseg_record = lidarseg_records[sample_index] + assert current_lidarseg_record.sample_data_token == calibrated_lidar_sample_data_token, ( + "Lidarseg record sample data token does not match the calibrated lidar sample data token." + ) + return current_lidarseg_record.filename + + def _extract_lidar_frame( + self, sample: Sample, sample_index: int, lidar_channel_name: str + ) -> LidarFrameDataModel: + """ + Extract lidar frame records from a T4 sample. + + Args: + sample: T4 Sample. + lidar_channel_name: Lidar channel name. + + Returns: + LidarDatasetRecord: Lidar records of the T4 sample. 
+ """ + + calibrated_lidar_sample_data_token = sample.data[lidar_channel_name] + sd_record: SampleData = self.t4_devkit_dataset.get( + SchemaName.SAMPLE_DATA, calibrated_lidar_sample_data_token + ) + cs_record: CalibratedSensor = self.t4_devkit_dataset.get( + SchemaName.CALIBRATED_SENSOR, sd_record.calibrated_sensor_token + ) + lidar_sensor_to_ego_matrix = convert_quaternion_to_matrix( + rotation_quaternion=cs_record.rotation, + translation=cs_record.translation, + convert_to_float32=False, + ) + + lidar_path, _, _ = self.t4_devkit_dataset.get_sample_data( + sample_data_token=calibrated_lidar_sample_data_token, + as_3d=True, + as_sensor_coord=True, + ) + + # Extract ego pose to global matrix in the lidar frame from the T4Dataset + ego_pose_record: EgoPose = self.t4_devkit_dataset.get( + SchemaName.EGO_POSE, sd_record.ego_pose_token + ) + lidar_frame_ego_pose_to_global_matrix = convert_quaternion_to_matrix( + rotation_quaternion=ego_pose_record.rotation, + translation=ego_pose_record.translation, + convert_to_float32=False, + ) + + # Etxract lidar pointcloud semantic mask path + lidar_pointcloud_semantic_mask_path = self._extract_lidar_pointcloud_semantic_mask_path( + sample_index=sample_index, + calibrated_lidar_sample_data_token=calibrated_lidar_sample_data_token, + lidar_pointcloud_source_path=sd_record.info_filename, + ) + + return LidarFrameDataModel( + lidar_frame_id=calibrated_lidar_sample_data_token, + lidar_keyframe=True, + lidar_sensor_id=cs_record.token, + lidar_sensor_channel_name=lidar_channel_name, + lidar_timestamp_seconds=microseconds2seconds(sd_record.timestamp), + lidar_pointcloud_path=lidar_path, + lidar_pointcloud_source_path=sd_record.info_filename, + lidar_pointcloud_num_features=self.lidar_pointcloud_num_features, + lidar_sensor_to_ego_pose_matrix=lidar_sensor_to_ego_matrix, + lidar_frame_ego_pose_to_global_matrix=lidar_frame_ego_pose_to_global_matrix, + lidar_sensor_to_lidar_sweep_matrices=np.eye( + 4 + ), # Always the identity matrix for the 
main lidar sensor + lidar_pointcloud_semantic_mask_path=lidar_pointcloud_semantic_mask_path, + ) + + def _compute_sensor_transformation_matrices( + self, + sensor_sample_data_record: SampleData, + selected_sensor_to_ego_pose_matrix: npt.NDArray[np.float64], + selected_sensor_frame_ego_pose_to_global_matrix: npt.NDArray[np.float64], + ) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: + """ + Compute transformation matrices for a sensor. + + Args: + sensor_sample_data_record: Sample data record of the sensor. + selected_sensor_to_ego_pose_matrix: Transformation matrix from the selected + sensor to its' the ego pose. + selected_sensor_frame_ego_pose_to_global_matrix: Transformation matrix from the selected + sensor frame ego pose to the global frame. + + Returns: + Tuple of transformation matrices: + 1. Sensor frame ego pose to global matrix (4x4) + 2. Selected sensor to sensor transformation matrix (4x4) + """ + sensor_calibrated_sensor_record: CalibratedSensor = self.t4_devkit_dataset.get( + SchemaName.CALIBRATED_SENSOR, sensor_sample_data_record.calibrated_sensor_token + ) + sensor_ego_pose_record: EgoPose = self.t4_devkit_dataset.get( + SchemaName.EGO_POSE, sensor_sample_data_record.ego_pose_token + ) + + sensor_to_ego_pose_tranlation = sensor_calibrated_sensor_record.translation + sensor_to_ego_pose_rotation = sensor_calibrated_sensor_record.rotation + + sensor_frame_ego_pose_to_global_translation = sensor_ego_pose_record.translation + sensor_frame_ego_pose_to_global_rotation = sensor_ego_pose_record.rotation + + sensor_frame_ego_pose_to_global_matrix = convert_quaternion_to_matrix( + rotation_quaternion=sensor_frame_ego_pose_to_global_rotation, + translation=sensor_frame_ego_pose_to_global_translation, + convert_to_float32=False, + ) + + sensor_to_ego_pose_matrix = convert_quaternion_to_matrix( + rotation_quaternion=sensor_to_ego_pose_rotation, + translation=sensor_to_ego_pose_tranlation, + convert_to_float32=False, + ) + + # Compute the 
transformation matrix of sensor to the selected sensor coordinate + # Sensor -> sensor frame ego pose -> global -> selected sensor frame ego pose -> selected sensor + # For example, if the sensor is a lidar sweep, and the selected sensor is the top lidar sweep: + # Sweep -> sweep frame ego pose -> global -> top lidar frame ego pose -> top lidar + # Right-to-left multiplication: + sensor_to_selected_sensor_matrix = ( + np.linalg.inv(selected_sensor_to_ego_pose_matrix) + @ np.linalg.inv(selected_sensor_frame_ego_pose_to_global_matrix) + @ sensor_frame_ego_pose_to_global_matrix + @ sensor_to_ego_pose_matrix + ) + return sensor_frame_ego_pose_to_global_matrix, sensor_to_selected_sensor_matrix + + def _extract_lidar_sweeps( + self, lidar_frame_data_model: LidarFrameDataModel + ) -> Sequence[LidarFrameDataModel]: + """ + Extract multisweep lidar metadata from a T4 Sample. + + Args: + t4_sample_record_lidar_info: T4 Sample lidar metadata. + + Returns: + LidarSweepsMetaData: T4 sample lidar sweep metadata + corresponding to the current T4 sample. 
+ """ + + current_lidar_sample_data_token = lidar_frame_data_model.lidar_frame_id + + lidar_frame_data_models = [] + current_sample_data_record: SampleData = self.t4_devkit_dataset.get( + SchemaName.SAMPLE_DATA, current_lidar_sample_data_token + ) + + for _ in range(self.max_sweeps): + # Stop processing if the current lidar sample data has no previous sample data + if not current_sample_data_record.prev: + break + + current_sample_data_record: SampleData = self.t4_devkit_dataset.get( + SchemaName.SAMPLE_DATA, current_sample_data_record.prev + ) + current_cs_record: CalibratedSensor = self.t4_devkit_dataset.get( + SchemaName.CALIBRATED_SENSOR, current_sample_data_record.calibrated_sensor_token + ) + current_lidar_sensor_to_ego_matrix = convert_quaternion_to_matrix( + rotation_quaternion=current_cs_record.rotation, + translation=current_cs_record.translation, + convert_to_float32=False, + ) + + # Get the current lidar sweep frame ego pose + lidar_sweep_transformations = self._compute_sensor_transformation_matrices( + sensor_sample_data_record=current_sample_data_record, + selected_sensor_to_ego_pose_matrix=lidar_frame_data_model.lidar_sensor_to_ego_pose_matrix, + selected_sensor_frame_ego_pose_to_global_matrix=lidar_frame_data_model.lidar_frame_ego_pose_to_global_matrix, + ) + lidar_sweep_frame_ego_pose_to_global_matrix, lidar_sweep_to_lidar_sensor_matrix = ( + lidar_sweep_transformations + ) + + # Inverse it to obtain the transformation matrix + # from the lidar sensor to the lidar sweeps + lidar_sensor_to_lidar_sweep_matrix = np.linalg.inv(lidar_sweep_to_lidar_sensor_matrix) + + lidar_sweep_pointcloud_path = self.t4_devkit_dataset.get_sample_data_path( + sample_data_token=current_sample_data_record.token + ) + + lidar_frame_data_models.append( + LidarFrameDataModel( + lidar_frame_id=current_sample_data_record.token, + lidar_keyframe=False, + lidar_sensor_id=current_cs_record.token, + lidar_sensor_channel_name=lidar_frame_data_model.lidar_sensor_channel_name, + 
lidar_timestamp_seconds=microseconds2seconds( + current_sample_data_record.timestamp + ), + lidar_pointcloud_path=lidar_sweep_pointcloud_path, + lidar_pointcloud_source_path=None, # Always None for lidar sweeps + lidar_pointcloud_num_features=self.lidar_pointcloud_num_features, + lidar_sensor_to_ego_pose_matrix=current_lidar_sensor_to_ego_matrix, + lidar_frame_ego_pose_to_global_matrix=lidar_sweep_frame_ego_pose_to_global_matrix, + lidar_sensor_to_lidar_sweep_matrices=lidar_sensor_to_lidar_sweep_matrix, + lidar_pointcloud_semantic_mask_path=None, # Always None for lidar sweeps + ) + ) + return lidar_frame_data_models + + def _extract_lidar_sources(self) -> Sequence[LidarSourceDataModel]: + """ + Extract lidar sources metadata from a T4 Sample. + + Args: + sample: T4 Sample. + + Returns: + LidarSourcesMetaData: Lidar sources metadata of the T4 sample. + """ + + # First, read lidar source sensor tokens from the sample data + calibrated_sensor_records: Sequence[CalibratedSensor] = getattr( + self.t4_devkit_dataset, SchemaName.CALIBRATED_SENSOR, [] + ) + + if not len(calibrated_sensor_records): + return [] + + lidar_source_channel_names = [] + lidar_source_data_models = [] + for calibrated_sensor_record in calibrated_sensor_records: + try: + sensor_record: Sensor = self.t4_devkit_dataset.get( + SchemaName.SENSOR, calibrated_sensor_record.sensor_token + ) + except ValueError: + continue + + modality = getattr(sensor_record, self.__MODALITY_STRING, None) + modality_value = getattr(modality, self.__VALUE_STRING, None) + if modality_value != Modality.LIDAR: + continue + + if sensor_record.channel not in lidar_source_channel_names: + lidar_source_channel_names.append(sensor_record.channel) + lidar_source_data_models.append( + LidarSourceDataModel( + channel_name=sensor_record.channel, + sensor_token=sensor_record.token, + translation=calibrated_sensor_record.translation, + rotation=calibrated_sensor_record.rotation.rotation_matrix, + ) + ) + + return 
lidar_source_data_models + + def _extract_category_mapping(self) -> CategoryMappingDataModel | None: + """ + Extract category metadata from a T4 Sample. + + Args: + sample_index: Sample index. + + Returns: + CategoryMetaData: Category metadata of the T4 sample. + """ + + category_records = self.t4_devkit_dataset.get_table(SchemaName.CATEGORY) + if not len(category_records): + return None + + category_names = [] + category_indices = [] + for category_record in category_records: + category_names.append(category_record.name) + category_indices.append(category_record.index) + + return CategoryMappingDataModel( + category_names=category_names, + category_indices=category_indices, + ) + + def extract_t4_sample_record(self, sample: Sample, sample_index: int) -> T4SampleRecord | None: """ Extract T4 sample record from a T4Dataset. @@ -150,32 +488,41 @@ def extract_t4_sample_record(self, sample: Sample, sample_index: int) -> T4Sampl T4SampleRecord: T4 sample record. """ - # First, read lidar token from the sample data + # Read lidar channel name if LidarChannel.LIDAR_TOP in sample.data: - lidar_token = sample.data[LidarChannel.LIDAR_TOP] + lidar_channel_name = LidarChannel.LIDAR_TOP elif LidarChannel.LIDAR_CONCAT in sample.data: - lidar_token = sample.data[LidarChannel.LIDAR_CONCAT] + lidar_channel_name = LidarChannel.LIDAR_CONCAT else: - raise ValueError( - f"Lidar channel {LidarChannel.LIDAR_TOP} or {LidarChannel.LIDAR_CONCAT} not found in sample data." 
- ) + return None - # Second, read sample data and calibrated sensor from the T4Dataset - sd_record: SampleData = self.t4_devkit_dataset.get("sample_data", lidar_token) - cs_record: CalibratedSensor = self.t4_devkit_dataset.get( - "calibrated_sensor", sd_record.calibrated_sensor_token + # 1) Extract basic information from the T4Dataset + frame_basic_metadata = self._extract_sample_basic_metadata( + sample=sample, sample_index=sample_index + ) + + # 2) Extract lidar information from the T4Dataset + lidar_frame_data_model = self._extract_lidar_frame( + sample=sample, lidar_channel_name=lidar_channel_name, sample_index=sample_index + ) + + # 3) Extract multisweep lidar information from the T4Dataset + lidar_sweep_data_models = self._extract_lidar_sweeps( + lidar_frame_data_model=lidar_frame_data_model ) - lidar_path, _, _ = self.t4_devkit_dataset.get_sample_data(lidar_token) - # TODO (KokSeang): Extract more information, for example, boxes and lidar sweeps, from the T4Dataset. - # Last, return the T4 sample record + + # Concat lidar frame data models and lidar sweep data models + lidar_frame_data_models = [lidar_frame_data_model] + lidar_sweep_data_models + + # 4) Extract lidar sources information from the T4Dataset + lidar_source_data_models = self._extract_lidar_sources() + + # 5) Extract category information from the T4Dataset + category_mapping_data_model = self._extract_category_mapping() return T4SampleRecord( - scenario_id=self.scenario_data.scenario_id, - sample_id=sample.token, - sample_index=sample_index, - location=self.scenario_data.location, - vehicle_type=self.scenario_data.vehicle_type, - lidar_path=lidar_path, - lidar2ego_translation=cs_record.translation, - lidar2ego_rotation=cs_record.rotation, + frame_basic_metadata=frame_basic_metadata, + lidar_frame_data_models=lidar_frame_data_models, + lidar_source_data_models=lidar_source_data_models, + category_mapping_data_model=category_mapping_data_model, ) diff --git 
a/autoware_ml/databases/t4dataset/t4sample_records.py b/autoware_ml/databases/t4dataset/t4sample_records.py new file mode 100644 index 0000000..054157a --- /dev/null +++ b/autoware_ml/databases/t4dataset/t4sample_records.py @@ -0,0 +1,55 @@ +# Copyright 2026 TIER IV, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence + +from pydantic import BaseModel, ConfigDict + +from autoware_ml.databases.schemas.frame_basic_metadata import FrameBasicMetadata +from autoware_ml.databases.schemas.dataset_schemas import DatasetRecord +from autoware_ml.databases.schemas.lidar_frames import LidarFrameDataModel +from autoware_ml.databases.schemas.lidar_sources import LidarSourceDataModel +from autoware_ml.databases.schemas.category_mapping import CategoryMappingDataModel + + +class T4SampleRecord(BaseModel): + """Temporary T4 sample record.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + frame_basic_metadata: FrameBasicMetadata + lidar_frame_data_models: Sequence[LidarFrameDataModel] + lidar_source_data_models: Sequence[LidarSourceDataModel] + category_mapping_data_model: CategoryMappingDataModel + + def to_dataset_record(self) -> DatasetRecord: + """ + Convert this T4SampleRecord to DatasetRecord. + + Returns: + DatasetRecord: Dataset record. 
+ """ + + return DatasetRecord( + scenario_id=self.frame_basic_metadata.scenario_id, + sample_id=self.frame_basic_metadata.sample_id, + sample_index=self.frame_basic_metadata.sample_index, + timestamp_seconds=self.frame_basic_metadata.timestamp_seconds, + scenario_name=self.frame_basic_metadata.scenario_name, + location=self.frame_basic_metadata.location, + vehicle_type=self.frame_basic_metadata.vehicle_type, + lidar_frames=self.lidar_frame_data_models, + lidar_sources=self.lidar_source_data_models, + category_mapping=self.category_mapping_data_model, + ) diff --git a/autoware_ml/utils/dataset.py b/autoware_ml/utils/dataset.py new file mode 100644 index 0000000..3a33f5c --- /dev/null +++ b/autoware_ml/utils/dataset.py @@ -0,0 +1,39 @@ +import numpy as np +import numpy.typing as npt +from pyquaternion import Quaternion + + +def convert_quaternion_to_matrix( + rotation_quaternion: Quaternion, + translation: npt.NDArray[np.float64] | None = None, + convert_to_float32: bool = False, +) -> npt.NDArray[np.float32]: # (4, 4) + """ + Convert a translation and quaternion to a 4x4 transformation matrix. + + Args: + rotation: Quaternion to represent the rotation. + translation (3x1 or None): Translation to represent the translation. + Returns: + npt.NDArray[np.float32]: 4x4 transformation matrix. + """ + + assert isinstance(rotation_quaternion, Quaternion), ( + "Rotation quaternion must be a Quaternion object" + ) + + result = np.eye(4) + result[:3, :3] = rotation_quaternion.rotation_matrix + + if translation is not None: + assert isinstance(translation, np.ndarray), "Translation must be a numpy array or None" + + if translation.shape != (3, 1) and translation.shape != (3,): + raise ValueError(f"Translation must be a 3x1 array, got shape {translation.shape}") + + result[:3, 3] = translation + + if convert_to_float32: + return result.astype(np.float32) + + return result