-
-
Notifications
You must be signed in to change notification settings - Fork 1
Change simulation encoding workaround to official rocketpy encoders #51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
9339a43
fb23890
ba7c28b
7c7f805
6ecc358
5eab472
8450082
d1f9a97
1ff3f2d
6342c68
8703230
b5924f9
93e19ee
114cc6e
40aed61
a26e01c
b574253
2510936
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,48 +1,151 @@ | ||
| # fork of https://github.com/encode/starlette/blob/master/starlette/middleware/gzip.py | ||
| import gzip | ||
| import io | ||
| import logging | ||
| import json | ||
| import copy | ||
| from datetime import datetime | ||
|
|
||
| from typing import Annotated, NoReturn, Any | ||
| import numpy as np | ||
| from typing import NoReturn, Tuple | ||
|
|
||
| from pydantic import PlainSerializer | ||
| from rocketpy import Function | ||
| from rocketpy._encoders import RocketPyEncoder | ||
| from starlette.datastructures import Headers, MutableHeaders | ||
| from starlette.types import ASGIApp, Message, Receive, Scope, Send | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| def to_python_primitive(v: Any) -> Any: | ||
|
|
||
| class DiscretizeConfig: | ||
| """ | ||
| Configuration class for RocketPy function discretization. | ||
|
|
||
| This class allows easy configuration of discretization parameters | ||
| for different types of RocketPy objects and their callable attributes. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, bounds: Tuple[float, float] = (0, 10), samples: int = 200 | ||
| ): | ||
| self.bounds = bounds | ||
| self.samples = samples | ||
|
|
||
| @classmethod | ||
| def for_environment(cls) -> 'DiscretizeConfig': | ||
| return cls(bounds=(0, 50000), samples=100) | ||
|
|
||
| @classmethod | ||
| def for_motor(cls) -> 'DiscretizeConfig': | ||
| return cls(bounds=(0, 10), samples=150) | ||
|
|
||
| @classmethod | ||
| def for_rocket(cls) -> 'DiscretizeConfig': | ||
| return cls(bounds=(0, 1), samples=100) | ||
|
|
||
| @classmethod | ||
| def for_flight(cls) -> 'DiscretizeConfig': | ||
| return cls(bounds=(0, 30), samples=200) | ||
|
|
||
|
|
||
| def rocketpy_encoder(obj, config: DiscretizeConfig = DiscretizeConfig()): | ||
GabrielBarberini marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| Convert complex types to Python primitives. | ||
| Encode a RocketPy object using official RocketPy encoders. | ||
|
|
||
| This function creates a copy of the object, discretizes callable Function | ||
| attributes on the copy, and then uses RocketPy's official RocketPyEncoder for | ||
| complete object serialization. The original object remains unchanged. | ||
|
|
||
| Args: | ||
| v: Any value, particularly those with a 'source' attribute | ||
| containing numpy arrays or generic types. | ||
| obj: RocketPy object (Environment, Motor, Rocket, Flight) | ||
| config: DiscretizeConfig object with discretization parameters (optional) | ||
|
|
||
| Returns: | ||
| The primitive representation of the input value. | ||
| Dictionary of encoded attributes | ||
| """ | ||
| if hasattr(v, "source"): | ||
| if isinstance(v.source, np.ndarray): | ||
| return v.source.tolist() | ||
|
|
||
| if isinstance(v.source, (np.generic,)): | ||
| return v.source.item() | ||
| # Create a copy to avoid mutating the original object | ||
| obj_copy = copy.deepcopy(obj) | ||
|
|
||
GabrielBarberini marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| for attr_name in dir(obj_copy): | ||
| if attr_name.startswith('_'): | ||
| continue | ||
|
|
||
| return str(v.source) | ||
| try: | ||
| attr_value = getattr(obj_copy, attr_name) | ||
| except Exception: | ||
| continue | ||
|
|
||
| if isinstance(v, (np.generic,)): | ||
| return v.item() | ||
| if callable(attr_value) and isinstance(attr_value, Function): | ||
| try: | ||
| discretized_func = Function(attr_value.source) | ||
GabrielBarberini marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| discretized_func.set_discrete( | ||
| lower=config.bounds[0], | ||
| upper=config.bounds[1], | ||
| samples=config.samples, | ||
| mutate_self=True, | ||
| ) | ||
|
|
||
| setattr(obj_copy, attr_name, discretized_func) | ||
|
|
||
| if isinstance(v, (np.ndarray,)): | ||
| return v.tolist() | ||
| except Exception as e: | ||
| logger.warning(f"Failed to discretize {attr_name}: {e}") | ||
|
|
||
| try: | ||
| json_str = json.dumps( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @phmbressan @GabrielBarberini not sure who this is for but when i run the env class simulation I get an incorrect date format here get_instance_attributes parses it as a natural datetime object. Is this a encoder issue? Observed this behavior while printing this
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it seems an issue within |
||
| obj_copy, | ||
| cls=RocketPyEncoder, | ||
| include_outputs=True, | ||
| include_function_data=True, | ||
| ) | ||
| encoded_result = json.loads(json_str) | ||
|
|
||
| return str(v) | ||
| # Post-process to fix datetime fields that got converted to lists | ||
| return _fix_datetime_fields(encoded_result) | ||
| except Exception as e: | ||
|
Comment on lines
+98
to
+109
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Naive datetimes reintroduced by post-processing; make them UTC-aware _fix_datetime_fields builds naive datetime objects, conflicting with the new UTC-aware defaults elsewhere (e.g., Environment defaults). This inconsistency can break consumers expecting tz-aware timestamps. Apply this diff: -from datetime import datetime
+from datetime import datetime, timezone
@@
- fixed[key] = datetime(
- year, month, day, hour, minute, second, microsecond
- )
+ fixed[key] = datetime(
+ year, month, day, hour, minute, second, microsecond, tzinfo=timezone.utc
+ )Also applies to: 123-155 🤖 Prompt for AI Agents |
||
| logger.warning(f"Failed to encode with RocketPyEncoder: {e}") | ||
| attributes = {} | ||
| for attr_name in dir(obj_copy): | ||
| if not attr_name.startswith('_'): | ||
| try: | ||
| attr_value = getattr(obj_copy, attr_name) | ||
| if not callable(attr_value): | ||
| attributes[attr_name] = str(attr_value) | ||
GabrielBarberini marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| except Exception: | ||
| continue | ||
| return attributes | ||
|
|
||
|
|
||
| AnyToPrimitive = Annotated[ | ||
| Any, | ||
| PlainSerializer(to_python_primitive), | ||
| ] | ||
| def _fix_datetime_fields(data): | ||
| """ | ||
| Fix datetime fields that RocketPyEncoder converted to lists. | ||
| """ | ||
| if isinstance(data, dict): | ||
| fixed = {} | ||
| for key, value in data.items(): | ||
| if ( | ||
| key in ['date', 'local_date', 'datetime_date'] | ||
| and isinstance(value, list) | ||
| and len(value) >= 3 | ||
| ): | ||
| # Convert [year, month, day, hour, ...] back to datetime | ||
| try: | ||
| year, month, day = value[0:3] | ||
| hour = value[3] if len(value) > 3 else 0 | ||
| minute = value[4] if len(value) > 4 else 0 | ||
| second = value[5] if len(value) > 5 else 0 | ||
| microsecond = value[6] if len(value) > 6 else 0 | ||
|
|
||
| fixed[key] = datetime( | ||
| year, month, day, hour, minute, second, microsecond | ||
| ) | ||
| except (ValueError, TypeError, IndexError): | ||
| # If conversion fails, keep the original value | ||
| fixed[key] = value | ||
| else: | ||
| fixed[key] = _fix_datetime_fields(value) | ||
| return fixed | ||
| if isinstance(data, (list, tuple)): | ||
| return [_fix_datetime_fields(item) for item in data] | ||
| return data | ||
|
|
||
|
|
||
| class RocketPyGZipMiddleware: | ||
|
|
@@ -70,6 +173,7 @@ async def __call__( | |
|
|
||
|
|
||
| class GZipResponder: | ||
| # fork of https://github.com/encode/starlette/blob/master/starlette/middleware/gzip.py | ||
| def __init__( | ||
| self, app: ASGIApp, minimum_size: int, compresslevel: int = 9 | ||
| ) -> None: | ||
|
|
@@ -161,6 +265,13 @@ async def send_with_gzip(self, message: Message) -> None: | |
|
|
||
| await self.send(message) | ||
|
|
||
| else: | ||
| # Pass through other message types unmodified. | ||
| if not self.started: | ||
| self.started = True | ||
| await self.send(self.initial_message) | ||
| await self.send(message) | ||
|
|
||
|
|
||
| async def unattached_send(message: Message) -> NoReturn: | ||
| raise RuntimeError("send awaitable not set") # pragma: no cover | ||
Uh oh!
There was an error while loading. Please reload this page.