|
25 | 25 | from ax.exceptions.core import UserInputError |
26 | 26 | from ax.generators.types import TConfig |
27 | 27 | from ax.utils.common.constants import Keys |
28 | | -from pandas import DataFrame |
| 28 | +from pandas import DataFrame, Series |
29 | 29 | from pyre_extensions import assert_is_instance, none_throws |
30 | 30 |
|
31 | 31 | if TYPE_CHECKING: |
@@ -314,11 +314,38 @@ def transform_experiment_data( |
314 | 314 | for p, param in self.search_space.parameters.items() |
315 | 315 | } |
316 | 316 | arm_data = arm_data.astype(dtype=column_to_type) |
317 | | - # Round to digits if any parameter specifies it. |
| 317 | + # Snap to the parameter's grid (digits or step_size) if specified. |
| 318 | + # These mirror ``RangeParameter.cast``'s rounding logic, but are applied |
| 319 | + # in a vectorized manner over the whole column rather than via a per-row |
| 320 | + # ``Series.apply`` (which calls ``parameter.cast`` once per element and is |
| 321 | + # slow for large DataFrames). NaN / ``<NA>`` values (added for missing |
| 322 | + # columns during the ``reindex`` above) propagate through ``round`` and |
| 323 | + # the arithmetic, matching the previous ``value if value is None`` guard. |
318 | 324 | for p_name in parameter_names: |
319 | 325 | parameter = self.search_space.parameters[p_name] |
320 | | - if isinstance(parameter, RangeParameter) and parameter.digits is not None: |
321 | | - arm_data[p_name] = arm_data[p_name].round(parameter.digits) |
| 326 | + if not isinstance(parameter, RangeParameter): |
| 327 | + continue |
| 328 | + column: Series = arm_data[p_name] |
| 329 | + if ( |
| 330 | + parameter.parameter_type is ParameterType.FLOAT |
| 331 | + and parameter.digits is not None |
| 332 | + ): |
| 333 | + # ``Series.round`` uses round-half-to-even, same as Python's |
| 334 | + # built-in ``round`` used in ``RangeParameter.cast``. |
| 335 | + arm_data[p_name] = column.round(parameter.digits) |
| 336 | + elif parameter.step_size is not None: |
| 337 | + # Snap to the grid ``{lower + k * step_size : k in Z}`` by |
| 338 | + # rounding ``(value - lower) / step_size`` to the nearest integer. |
| 339 | + lower = float(parameter.lower) |
| 340 | + step_size = none_throws(parameter.step_size) |
| 341 | + steps: Series = column.sub(lower).div(step_size).round() |
| 342 | + snapped: Series = steps.mul(step_size).add(lower) |
| 343 | + if parameter.parameter_type is ParameterType.INT: |
| 344 | + # Preserve the nullable ``Int64`` dtype so reindex-added |
| 345 | + # ``<NA>`` values survive the cast. |
| 346 | + arm_data[p_name] = snapped.round().astype("Int64") |
| 347 | + else: |
| 348 | + arm_data[p_name] = snapped |
322 | 349 |
|
323 | 350 | return ExperimentData(arm_data=arm_data, observation_data=observation_data) |
324 | 351 |
|
|
0 commit comments