Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/models/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class ApiBaseModel(BaseModel, ABC):
validate_default=True,
validate_all_in_root=True,
validate_assignment=True,
ser_json_exclude_none=True,
)

def set_id(self, value):
Expand Down
6 changes: 4 additions & 2 deletions src/repositories/flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@ def __init__(self):

@repository_exception_handler
async def create_flight(self, flight: FlightModel) -> str:
return await self.insert(flight.model_dump())
return await self.insert(flight.model_dump(exclude_none=True))

@repository_exception_handler
async def read_flight_by_id(self, flight_id: str) -> Optional[FlightModel]:
return await self.find_by_id(data_id=flight_id)

@repository_exception_handler
async def update_flight_by_id(self, flight_id: str, flight: FlightModel):
await self.update_by_id(flight.model_dump(), data_id=flight_id)
await self.update_by_id(
flight.model_dump(exclude_none=True), data_id=flight_id
)

@repository_exception_handler
async def delete_flight_by_id(self, flight_id: str):
Expand Down
6 changes: 4 additions & 2 deletions src/repositories/motor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@ def __init__(self):

@repository_exception_handler
async def create_motor(self, motor: MotorModel) -> str:
return await self.insert(motor.model_dump())
return await self.insert(motor.model_dump(exclude_none=True))

@repository_exception_handler
async def read_motor_by_id(self, motor_id: str) -> Optional[MotorModel]:
return await self.find_by_id(data_id=motor_id)

@repository_exception_handler
async def update_motor_by_id(self, motor_id: str, motor: MotorModel):
await self.update_by_id(motor.model_dump(), data_id=motor_id)
await self.update_by_id(
motor.model_dump(exclude_none=True), data_id=motor_id
)

@repository_exception_handler
async def delete_motor_by_id(self, motor_id: str):
Expand Down
6 changes: 4 additions & 2 deletions src/repositories/rocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@ def __init__(self):

@repository_exception_handler
async def create_rocket(self, rocket: RocketModel) -> str:
return await self.insert(rocket.model_dump())
return await self.insert(rocket.model_dump(exclude_none=True))

@repository_exception_handler
async def read_rocket_by_id(self, rocket_id: str) -> Optional[RocketModel]:
return await self.find_by_id(data_id=rocket_id)

@repository_exception_handler
async def update_rocket_by_id(self, rocket_id: str, rocket: RocketModel):
await self.update_by_id(rocket.model_dump(), data_id=rocket_id)
await self.update_by_id(
rocket.model_dump(exclude_none=True), data_id=rocket_id
)

@repository_exception_handler
async def delete_rocket_by_id(self, rocket_id: str):
Expand Down
6 changes: 4 additions & 2 deletions src/services/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import dill

from rocketpy.environment.environment import Environment as RocketPyEnvironment
from rocketpy.utilities import get_instance_attributes
from src.models.environment import EnvironmentModel
from src.views.environment import EnvironmentSimulation
from src.utils import rocketpy_encoder, DiscretizeConfig


class EnvironmentService:
Expand Down Expand Up @@ -50,7 +50,9 @@ def get_environment_simulation(self) -> EnvironmentSimulation:
EnvironmentSimulation
"""

attributes = get_instance_attributes(self.environment)
attributes = rocketpy_encoder(
self.environment, DiscretizeConfig.for_environment()
)
env_simulation = EnvironmentSimulation(**attributes)
return env_simulation

Expand Down
6 changes: 4 additions & 2 deletions src/services/flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import dill

from rocketpy.simulation.flight import Flight as RocketPyFlight
from rocketpy.utilities import get_instance_attributes

from src.services.environment import EnvironmentService
from src.services.rocket import RocketService
from src.models.flight import FlightModel
from src.views.flight import FlightSimulation
from src.utils import rocketpy_encoder, DiscretizeConfig


class FlightService:
Expand Down Expand Up @@ -55,7 +55,9 @@ def get_flight_simulation(self) -> FlightSimulation:
Returns:
FlightSimulation
"""
attributes = get_instance_attributes(self.flight)
attributes = rocketpy_encoder(
self.flight, DiscretizeConfig.for_flight()
)
flight_simulation = FlightSimulation(**attributes)
return flight_simulation

Expand Down
4 changes: 2 additions & 2 deletions src/services/motor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from rocketpy.motors.solid_motor import SolidMotor
from rocketpy.motors.liquid_motor import LiquidMotor
from rocketpy.motors.hybrid_motor import HybridMotor
from rocketpy.utilities import get_instance_attributes
from rocketpy import (
LevelBasedTank,
MassBasedTank,
Expand All @@ -18,6 +17,7 @@
from src.models.sub.tanks import TankKinds
from src.models.motor import MotorKinds, MotorModel
from src.views.motor import MotorSimulation
from src.utils import rocketpy_encoder, DiscretizeConfig


class MotorService:
Expand Down Expand Up @@ -140,7 +140,7 @@ def get_motor_simulation(self) -> MotorSimulation:
Returns:
MotorSimulation
"""
attributes = get_instance_attributes(self.motor)
attributes = rocketpy_encoder(self.motor, DiscretizeConfig.for_motor())
motor_simulation = MotorSimulation(**attributes)
return motor_simulation

Expand Down
6 changes: 4 additions & 2 deletions src/services/rocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
Fins as RocketPyFins,
Tail as RocketPyTail,
)
from rocketpy.utilities import get_instance_attributes

from src import logger
from src.models.rocket import RocketModel, Parachute
from src.models.sub.aerosurfaces import NoseCone, Tail, Fins
from src.services.motor import MotorService
from src.views.rocket import RocketSimulation
from src.utils import rocketpy_encoder, DiscretizeConfig


class RocketService:
Expand Down Expand Up @@ -107,7 +107,9 @@ def get_rocket_simulation(self) -> RocketSimulation:
Returns:
RocketSimulation
"""
attributes = get_instance_attributes(self.rocket)
attributes = rocketpy_encoder(
self.rocket, DiscretizeConfig.for_rocket()
)
rocket_simulation = RocketSimulation(**attributes)
return rocket_simulation

Expand Down
159 changes: 135 additions & 24 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,151 @@
# fork of https://github.com/encode/starlette/blob/master/starlette/middleware/gzip.py
import gzip
import io
import logging
import json
import copy
from datetime import datetime

from typing import Annotated, NoReturn, Any
import numpy as np
from typing import NoReturn, Tuple

from pydantic import PlainSerializer
from rocketpy import Function
from rocketpy._encoders import RocketPyEncoder
from starlette.datastructures import Headers, MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send

logger = logging.getLogger(__name__)

def to_python_primitive(v: Any) -> Any:

class DiscretizeConfig:
"""
Configuration class for RocketPy function discretization.

This class allows easy configuration of discretization parameters
for different types of RocketPy objects and their callable attributes.
"""

def __init__(
self, bounds: Tuple[float, float] = (0, 10), samples: int = 200
):
self.bounds = bounds
self.samples = samples

@classmethod
def for_environment(cls) -> 'DiscretizeConfig':
return cls(bounds=(0, 50000), samples=100)

@classmethod
def for_motor(cls) -> 'DiscretizeConfig':
return cls(bounds=(0, 10), samples=150)

@classmethod
def for_rocket(cls) -> 'DiscretizeConfig':
return cls(bounds=(0, 1), samples=100)

@classmethod
def for_flight(cls) -> 'DiscretizeConfig':
return cls(bounds=(0, 30), samples=200)


def rocketpy_encoder(obj, config: DiscretizeConfig = DiscretizeConfig()):
"""
Convert complex types to Python primitives.
Encode a RocketPy object using official RocketPy encoders.

This function creates a copy of the object, discretizes callable Function
attributes on the copy, and then uses RocketPy's official RocketPyEncoder for
complete object serialization. The original object remains unchanged.

Args:
v: Any value, particularly those with a 'source' attribute
containing numpy arrays or generic types.
obj: RocketPy object (Environment, Motor, Rocket, Flight)
config: DiscretizeConfig object with discretization parameters (optional)

Returns:
The primitive representation of the input value.
Dictionary of encoded attributes
"""
if hasattr(v, "source"):
if isinstance(v.source, np.ndarray):
return v.source.tolist()

if isinstance(v.source, (np.generic,)):
return v.source.item()
# Create a copy to avoid mutating the original object
obj_copy = copy.deepcopy(obj)

for attr_name in dir(obj_copy):
if attr_name.startswith('_'):
continue

return str(v.source)
try:
attr_value = getattr(obj_copy, attr_name)
except Exception:
continue

if isinstance(v, (np.generic,)):
return v.item()
if callable(attr_value) and isinstance(attr_value, Function):
try:
discretized_func = Function(attr_value.source)
discretized_func.set_discrete(
lower=config.bounds[0],
upper=config.bounds[1],
samples=config.samples,
mutate_self=True,
)

setattr(obj_copy, attr_name, discretized_func)

if isinstance(v, (np.ndarray,)):
return v.tolist()
except Exception as e:
logger.warning(f"Failed to discretize {attr_name}: {e}")

try:
json_str = json.dumps(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@phmbressan @GabrielBarberini not sure who this is for but when i run the env class simulation I get an incorrect date format here

"date": [2025, 6, 9, 23]

get_instance_attributes parses it as a natural datetime object. Is this a encoder issue? Observed this behavior while printing this json_str variable.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems an issue within RocketPyEncoder, I've solved it with a temporary post-processing hack here: 8703230

obj_copy,
cls=RocketPyEncoder,
include_outputs=True,
include_function_data=True,
)
encoded_result = json.loads(json_str)

return str(v)
# Post-process to fix datetime fields that got converted to lists
return _fix_datetime_fields(encoded_result)
except Exception as e:
Comment on lines +98 to +109
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Naive datetimes reintroduced by post-processing; make them UTC-aware

_fix_datetime_fields builds naive datetime objects, conflicting with the new UTC-aware defaults elsewhere (e.g., Environment defaults). This inconsistency can break consumers expecting tz-aware timestamps.

Apply this diff:

-from datetime import datetime
+from datetime import datetime, timezone
@@
-                    fixed[key] = datetime(
-                        year, month, day, hour, minute, second, microsecond
-                    )
+                    fixed[key] = datetime(
+                        year, month, day, hour, minute, second, microsecond, tzinfo=timezone.utc
+                    )

Also applies to: 123-155

🤖 Prompt for AI Agents
In src/utils.py around lines 98 to 109 (and also apply same change to lines
123-155), _fix_datetime_fields currently reconstructs datetime objects as naive
datetimes; update the post-processing so any datetime reconstructed from
lists/tuples is made timezone-aware by setting UTC (e.g., use datetime(...,
tzinfo=timezone.utc) or .replace(tzinfo=timezone.utc)) instead of leaving them
naive, ensuring all returned datetimes are UTC-aware and consistent with
Environment defaults; adjust any imports and tests as needed to use
datetime.timezone.utc.

logger.warning(f"Failed to encode with RocketPyEncoder: {e}")
attributes = {}
for attr_name in dir(obj_copy):
if not attr_name.startswith('_'):
try:
attr_value = getattr(obj_copy, attr_name)
if not callable(attr_value):
attributes[attr_name] = str(attr_value)
except Exception:
continue
return attributes


AnyToPrimitive = Annotated[
Any,
PlainSerializer(to_python_primitive),
]
def _fix_datetime_fields(data):
"""
Fix datetime fields that RocketPyEncoder converted to lists.
"""
if isinstance(data, dict):
fixed = {}
for key, value in data.items():
if (
key in ['date', 'local_date', 'datetime_date']
and isinstance(value, list)
and len(value) >= 3
):
# Convert [year, month, day, hour, ...] back to datetime
try:
year, month, day = value[0:3]
hour = value[3] if len(value) > 3 else 0
minute = value[4] if len(value) > 4 else 0
second = value[5] if len(value) > 5 else 0
microsecond = value[6] if len(value) > 6 else 0

fixed[key] = datetime(
year, month, day, hour, minute, second, microsecond
)
except (ValueError, TypeError, IndexError):
# If conversion fails, keep the original value
fixed[key] = value
else:
fixed[key] = _fix_datetime_fields(value)
return fixed
if isinstance(data, (list, tuple)):
return [_fix_datetime_fields(item) for item in data]
return data


class RocketPyGZipMiddleware:
Expand Down Expand Up @@ -70,6 +173,7 @@ async def __call__(


class GZipResponder:
# fork of https://github.com/encode/starlette/blob/master/starlette/middleware/gzip.py
def __init__(
self, app: ASGIApp, minimum_size: int, compresslevel: int = 9
) -> None:
Expand Down Expand Up @@ -161,6 +265,13 @@ async def send_with_gzip(self, message: Message) -> None:

await self.send(message)

else:
# Pass through other message types unmodified.
if not self.started:
self.started = True
await self.send(self.initial_message)
await self.send(message)


async def unattached_send(message: Message) -> NoReturn:
raise RuntimeError("send awaitable not set") # pragma: no cover
Loading
Loading