Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions autoware_ml/common/enums/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@
from enum import Enum


class Modality(str, Enum):
"""
Modality.

Attributes:
LIDAR: Lidar modality.
CAMERA: Camera modality.
RADAR: Radar modality.
"""

LIDAR = "lidar"
CAMERA = "camera"
RADAR = "radar"


class SplitType(str, Enum):
"""
Split type.
Expand Down
3 changes: 3 additions & 0 deletions autoware_ml/configs/database/t4dataset/t4dataset_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,8 @@ defaults:
- t4dataset/scenarios@scenarios.db_jpntaxi_base: detection3d/db_jpntaxi_base
- t4dataset/scenarios@scenarios.db_largebus: detection3d/db_largebus

# Number of features in the lidar pointcloud
lidar_pointcloud_num_features: 5

# Processor settings
num_workers: 16
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,8 @@ defaults:
- t4dataset/scenarios@scenarios.db_j6gen2: detection3d/db_j6gen2
- t4dataset/scenarios@scenarios.db_largebus: detection3d/db_largebus

# Number of features in the lidar pointcloud
lidar_pointcloud_num_features: 5

# Processor settings
num_workers: 16
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,8 @@ cache_file_prefix_name: database
defaults:
- t4dataset/scenarios@scenarios.db_jpntaxi_base: detection3d/db_jpntaxi_base

# Number of features in the lidar pointcloud
lidar_pointcloud_num_features: 5

# Processor settings
num_workers: 16
2 changes: 1 addition & 1 deletion autoware_ml/databases/base_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import polars as pl

from autoware_ml.databases.scenarios import Scenarios, ScenarioData
from autoware_ml.databases.schemas import DatasetRecord, DatasetTableSchema
from autoware_ml.databases.schemas.dataset_schemas import DatasetRecord, DatasetTableSchema

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion autoware_ml/databases/database_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from types import MappingProxyType

from autoware_ml.databases.scenarios import Scenarios, ScenarioData
from autoware_ml.databases.schemas import DatasetRecord
from autoware_ml.databases.schemas.dataset_schemas import DatasetRecord


class DatabaseInterface(Protocol):
Expand Down
Empty file.
87 changes: 87 additions & 0 deletions autoware_ml/databases/schemas/base_schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2026 TIER IV, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import NamedTuple, Sequence, Mapping, Any

import polars as pl


class DatasetTableColumn(NamedTuple):
"""
Annotation table column.

Attributes:
name: Name of the column.
dtype: Data type of the column.
"""

name: str
dtype: pl.DataType


class BaseFieldSchema:
"""
Base class for field schemas.
"""

@classmethod
def to_polars_field_schema(cls) -> Sequence[pl.Field]:
"""
Convert the lidar column schema to a Polars field schema.

Returns:
pl.Schema: Polars schema.
"""

return [
pl.Field(v.name, v.dtype)
for k, v in cls.__dict__.items()
if not k.startswith("__") and isinstance(v, DatasetTableColumn)
]


class DataModelInterface(ABC):
"""
Interface for data models.
"""

@abstractmethod
def to_dictionary(self) -> Mapping[str, Any]:
"""
Convert the data model to a dictionary.

Returns:
Mapping[str, Any]: Dictionary representation of the data model.
"""

raise NotImplementedError("Subclasses must implement to_dictionary!")

@classmethod
def load_from_dictionary(cls, data_model: Mapping[str, Any]) -> DataModelInterface:
"""
Load the data model and decode it to the corresponding data model from a dictionary, which is
deserialized from a Polars dataframe.

Args:
data_model: Dictionary representation of the data model, which is
deserialized from a Polars dataframe.

Returns:
DataModelInterface: Data model.
"""

raise NotImplementedError("Subclasses must implement load_from_dictionary!")
85 changes: 85 additions & 0 deletions autoware_ml/databases/schemas/category_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2026 TIER IV, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence, Any, Mapping

import polars as pl
from pydantic import BaseModel, ConfigDict

from autoware_ml.databases.schemas.base_schemas import (
BaseFieldSchema,
DatasetTableColumn,
DataModelInterface,
)


@dataclass(frozen=True)
class CategoryMappingDatasetSchema(BaseFieldSchema):
"""
Dataclass to define polars schema for columns related to category mapping.
"""

CATEGORY_NAMES = DatasetTableColumn("category_names", pl.List(pl.String))
CATEGORY_INDICES = DatasetTableColumn("category_indices", pl.List(pl.Int32))


class CategoryMappingDataModel(BaseModel, DataModelInterface):
"""
Category mapping data model that can be shared by multiple datasets. It saves the mapping
between category names and category indices.

Attributes:
category_names: List of category names.
category_indices: List of category indices.
"""

model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True)

category_names: Sequence[str]
category_indices: Sequence[int]

def model_post_init(self, __context: Any) -> None:
"""Validate that all attributes are of the same length."""

assert len(self.category_names) == len(self.category_indices), (
"All attributes must be of the same length"
)

def to_dictionary(self) -> Mapping[str, Any]:
"""
Convert the category mapping data model to a dictionary.

Returns:
Mapping[str, Any]: Dictionary representation of the category mapping data model.
"""

return self.model_dump()

@classmethod
def load_from_dictionary(cls, data_model: Mapping[str, Any]) -> CategoryMappingDataModel:
"""
Load the category mapping data model and decode it to the corresponding CategoryMappingDataModel
from a dictionary, which is deserialized from a Polars dataframe.

Args:
data_model: Dictionary representation of the category mapping data model, which is
deserialized from a Polars dataframe.
"""
return cls(
category_names=data_model[CategoryMappingDatasetSchema.CATEGORY_NAMES.name],
category_indices=data_model[CategoryMappingDatasetSchema.CATEGORY_INDICES.name],
)
Loading
Loading