Skip to content

Commit 541c458

Browse files
Balandatmeta-codesync[bot]
authored andcommitted
Back out "Phase 3: Add Self type annotations to fluent/clone methods" (#4882)
Summary: Pull Request resolved: #4882 Unlanding this temporarily for compatibility reasons. Reviewed By: tonykao8080 Differential Revision: D92858617 fbshipit-source-id: be6a8fbc3591cbdbd0175b8a4b79212e95907bab
1 parent 8a706b8 commit 541c458

File tree

9 files changed

+51
-56
lines changed

9 files changed

+51
-56
lines changed

ax/core/arm.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import hashlib
1010
import json
1111
from collections.abc import Mapping
12-
from typing import Self
1312

1413
from ax.core.types import TParameterization, TParamValue
1514
from ax.utils.common.base import SortableBase
@@ -94,7 +93,7 @@ def md5hash(parameters: Mapping[str, TParamValue]) -> str:
9493
parameters_str = json.dumps(parameters, sort_keys=True)
9594
return hashlib.md5(parameters_str.encode("utf-8")).hexdigest()
9695

97-
def clone(self, clear_name: bool = False) -> Self:
96+
def clone(self, clear_name: bool = False) -> "Arm":
9897
"""Create a copy of this arm.
9998
10099
Args:
@@ -103,7 +102,7 @@ def clone(self, clear_name: bool = False) -> Self:
103102
Defaults to False.
104103
"""
105104
clear_name = clear_name or not self.has_name
106-
return self.__class__(
105+
return Arm(
107106
parameters=self.parameters.copy(), name=None if clear_name else self.name
108107
)
109108

ax/core/base_trial.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def update_stop_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
324324
self._stop_metadata.update(metadata)
325325
return self._stop_metadata
326326

327-
def run(self) -> Self:
327+
def run(self) -> BaseTrial:
328328
"""Deploys the trial according to the behavior on the runner.
329329
330330
The runner returns a `run_metadata` dict containining metadata
@@ -349,7 +349,7 @@ def run(self) -> Self:
349349
self.mark_running()
350350
return self
351351

352-
def stop(self, new_status: TrialStatus, reason: str | None = None) -> Self:
352+
def stop(self, new_status: TrialStatus, reason: str | None = None) -> BaseTrial:
353353
"""Stops the trial according to the behavior on the runner.
354354
355355
The runner returns a `stop_metadata` dict containining metadata
@@ -384,7 +384,7 @@ def stop(self, new_status: TrialStatus, reason: str | None = None) -> Self:
384384
self.mark_as(new_status)
385385
return self
386386

387-
def complete(self, reason: str | None = None) -> Self:
387+
def complete(self, reason: str | None = None) -> BaseTrial:
388388
"""Stops the trial if functionality is defined on runner
389389
and marks trial completed.
390390
@@ -524,7 +524,7 @@ def status_reason(self) -> str | None:
524524
"""Reason string for the trial status (failed, abandoned, or early stopped)."""
525525
return self._status_reason
526526

527-
def mark_staged(self, unsafe: bool = False) -> Self:
527+
def mark_staged(self, unsafe: bool = False) -> BaseTrial:
528528
"""Mark the trial as being staged for running.
529529
530530
Args:
@@ -542,7 +542,7 @@ def mark_staged(self, unsafe: bool = False) -> Self:
542542

543543
def mark_running(
544544
self, no_runner_required: bool = False, unsafe: bool = False
545-
) -> Self:
545+
) -> BaseTrial:
546546
"""Mark trial has started running.
547547
548548
Args:
@@ -572,7 +572,7 @@ def mark_running(
572572

573573
def mark_completed(
574574
self, unsafe: bool = False, time_completed: str | None = None
575-
) -> Self:
575+
) -> BaseTrial:
576576
"""Mark trial as completed.
577577
578578
Args:
@@ -596,7 +596,9 @@ def mark_completed(
596596
)
597597
return self
598598

599-
def mark_abandoned(self, reason: str | None = None, unsafe: bool = False) -> Self:
599+
def mark_abandoned(
600+
self, reason: str | None = None, unsafe: bool = False
601+
) -> BaseTrial:
600602
"""Mark trial as abandoned.
601603
602604
NOTE: Arms in abandoned trials are considered to be 'pending points'
@@ -622,7 +624,7 @@ def mark_abandoned(self, reason: str | None = None, unsafe: bool = False) -> Sel
622624
self._time_completed = datetime.now()
623625
return self
624626

625-
def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> Self:
627+
def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> BaseTrial:
626628
"""Mark trial as failed.
627629
628630
Args:
@@ -642,7 +644,7 @@ def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> Self:
642644

643645
def mark_early_stopped(
644646
self, reason: str | None = None, unsafe: bool = False
645-
) -> Self:
647+
) -> BaseTrial:
646648
"""Mark trial as early stopped.
647649
648650
Args:
@@ -668,7 +670,7 @@ def mark_early_stopped(
668670
self._time_completed = datetime.now()
669671
return self
670672

671-
def mark_stale(self, unsafe: bool = False) -> Self:
673+
def mark_stale(self, unsafe: bool = False) -> BaseTrial:
672674
"""Mark trial as stale.
673675
674676
Args:
@@ -689,7 +691,9 @@ def mark_stale(self, unsafe: bool = False) -> Self:
689691
self._time_completed = datetime.now()
690692
return self
691693

692-
def mark_as(self, status: TrialStatus, unsafe: bool = False, **kwargs: Any) -> Self:
694+
def mark_as(
695+
self, status: TrialStatus, unsafe: bool = False, **kwargs: Any
696+
) -> BaseTrial:
693697
"""Mark trial with a new TrialStatus.
694698
695699
Args:
@@ -720,7 +724,7 @@ def mark_as(self, status: TrialStatus, unsafe: bool = False, **kwargs: Any) -> S
720724
raise TrialMutationError(f"Cannot mark trial as {status}.")
721725
return self
722726

723-
def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> Self:
727+
def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> BaseTrial:
724728
raise NotImplementedError(
725729
"Abandoning arms is only supported for `BatchTrial`. "
726730
"Use `trial.mark_abandoned` if applicable."

ax/core/batch_trial.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from dataclasses import dataclass
1414
from datetime import datetime
1515
from logging import Logger
16-
from typing import Any, Self, TYPE_CHECKING
16+
from typing import Any, TYPE_CHECKING
1717

1818
import numpy as np
1919
from ax.core.arm import Arm
@@ -467,7 +467,9 @@ def normalized_arm_weights(
467467
weights = weights * (total / np.sum(weights))
468468
return OrderedDict(zip(self.arms, weights))
469469

470-
def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> Self:
470+
def mark_arm_abandoned(
471+
self, arm_name: str, reason: str | None = None
472+
) -> BatchTrial:
471473
"""Mark a arm abandoned.
472474
473475
Usually done after deployment when one arm causes issues but

ax/core/experiment.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from collections.abc import Hashable, Iterable, Mapping, Sequence
1616
from datetime import datetime
1717
from functools import partial, reduce
18-
from typing import Any, cast, Self, Union
18+
from typing import Any, cast, Union
1919

2020
import ax.core.observation as observation
2121
import pandas as pd
@@ -553,7 +553,7 @@ def immutable_search_space_and_opt_config(self) -> bool:
553553
def tracking_metrics(self) -> list[Metric]:
554554
return list(self._tracking_metrics.values())
555555

556-
def add_tracking_metric(self, metric: Metric) -> Self:
556+
def add_tracking_metric(self, metric: Metric) -> Experiment:
557557
"""Add a new metric to the experiment.
558558
559559
Args:
@@ -576,7 +576,7 @@ def add_tracking_metric(self, metric: Metric) -> Self:
576576
self._tracking_metrics[metric.name] = metric
577577
return self
578578

579-
def add_tracking_metrics(self, metrics: list[Metric]) -> Self:
579+
def add_tracking_metrics(self, metrics: list[Metric]) -> Experiment:
580580
"""Add a list of new metrics to the experiment.
581581
582582
If any of the metrics are already defined on the experiment,
@@ -591,7 +591,7 @@ def add_tracking_metrics(self, metrics: list[Metric]) -> Self:
591591
self.add_tracking_metric(metric)
592592
return self
593593

594-
def update_tracking_metric(self, metric: Metric) -> Self:
594+
def update_tracking_metric(self, metric: Metric) -> Experiment:
595595
"""Redefine a metric that already exists on the experiment.
596596
597597
Args:
@@ -603,7 +603,7 @@ def update_tracking_metric(self, metric: Metric) -> Self:
603603
self._tracking_metrics[metric.name] = metric
604604
return self
605605

606-
def remove_tracking_metric(self, metric_name: str) -> Self:
606+
def remove_tracking_metric(self, metric_name: str) -> Experiment:
607607
"""Remove a metric that already exists on the experiment.
608608
609609
Args:

ax/core/multi_type_experiment.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99
from collections.abc import Iterable, Sequence
10-
from typing import Any, Self
10+
from typing import Any
1111

1212
from ax.core.arm import Arm
1313
from ax.core.base_trial import BaseTrial, TrialStatus
@@ -96,7 +96,7 @@ def __init__(
9696
default_data_type=default_data_type,
9797
)
9898

99-
def add_trial_type(self, trial_type: str, runner: Runner) -> Self:
99+
def add_trial_type(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment":
100100
"""Add a new trial_type to be supported by this experiment.
101101
102102
Args:
@@ -122,7 +122,7 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
122122
self.default_trial_type
123123
)
124124

125-
def update_runner(self, trial_type: str, runner: Runner) -> Self:
125+
def update_runner(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment":
126126
"""Update the default runner for an existing trial_type.
127127
128128
Args:
@@ -141,7 +141,7 @@ def add_tracking_metric(
141141
metric: Metric,
142142
trial_type: str | None = None,
143143
canonical_name: str | None = None,
144-
) -> Self:
144+
) -> "MultiTypeExperiment":
145145
"""Add a new metric to the experiment.
146146
147147
Args:
@@ -199,26 +199,18 @@ def add_tracking_metrics(
199199
)
200200
return self
201201

202+
# pyre-fixme[14]: `update_tracking_metric` overrides method defined in
203+
# `Experiment` inconsistently.
202204
def update_tracking_metric(
203-
self,
204-
metric: Metric,
205-
trial_type: str | None = None,
206-
canonical_name: str | None = None,
207-
) -> Self:
205+
self, metric: Metric, trial_type: str, canonical_name: str | None = None
206+
) -> "MultiTypeExperiment":
208207
"""Update an existing metric on the experiment.
209208
210209
Args:
211210
metric: The metric to add.
212-
trial_type: The trial type for which this metric is used. Defaults to
213-
the current trial type of the metric (if set), or the default trial
214-
type otherwise.
211+
trial_type: The trial type for which this metric is used.
215212
canonical_name: The default metric for which this metric is a proxy.
216213
"""
217-
# Default to the existing trial type if not specified
218-
if trial_type is None:
219-
trial_type = self._metric_to_trial_type.get(
220-
metric.name, self._default_trial_type
221-
)
222214
oc = self.optimization_config
223215
oc_metrics = oc.metrics if oc else []
224216
if metric.name in oc_metrics and trial_type != self._default_trial_type:
@@ -231,13 +223,13 @@ def update_tracking_metric(
231223
raise ValueError(f"`{trial_type}` is not a supported trial type.")
232224

233225
super().update_tracking_metric(metric)
234-
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
226+
self._metric_to_trial_type[metric.name] = trial_type
235227
if canonical_name is not None:
236228
self._metric_to_canonical_name[metric.name] = canonical_name
237229
return self
238230

239231
@copy_doc(Experiment.remove_tracking_metric)
240-
def remove_tracking_metric(self, metric_name: str) -> Self:
232+
def remove_tracking_metric(self, metric_name: str) -> "MultiTypeExperiment":
241233
if metric_name not in self._tracking_metrics:
242234
raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.")
243235

ax/core/objective.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from __future__ import annotations
1010

1111
from collections.abc import Iterable
12-
from typing import Self
1312

1413
from ax.core.metric import Metric
1514
from ax.exceptions.core import UserInputError
@@ -73,9 +72,9 @@ def metric_signatures(self) -> list[str]:
7372
"""Get a list of objective metric signatures."""
7473
return [m.signature for m in self.metrics]
7574

76-
def clone(self) -> Self:
75+
def clone(self) -> Objective:
7776
"""Create a copy of the objective."""
78-
return self.__class__(self.metric.clone(), self.minimize)
77+
return Objective(self.metric.clone(), self.minimize)
7978

8079
def __repr__(self) -> str:
8180
return 'Objective(metric_name="{}", minimize={})'.format(
@@ -130,9 +129,9 @@ def objectives(self) -> list[Objective]:
130129
"""Get the objectives."""
131130
return self._objectives
132131

133-
def clone(self) -> Self:
132+
def clone(self) -> MultiObjective:
134133
"""Create a copy of the objective."""
135-
return self.__class__(objectives=[o.clone() for o in self.objectives])
134+
return MultiObjective(objectives=[o.clone() for o in self.objectives])
136135

137136
def __repr__(self) -> str:
138137
return f"MultiObjective(objectives={self.objectives})"
@@ -220,9 +219,9 @@ def expression(self) -> str:
220219

221220
return " + ".join(parts).replace(" + -", " - ")
222221

223-
def clone(self) -> Self:
222+
def clone(self) -> ScalarizedObjective:
224223
"""Create a copy of the objective."""
225-
return self.__class__(
224+
return ScalarizedObjective(
226225
metrics=[m.clone() for m in self.metrics],
227226
weights=self.weights.copy(),
228227
minimize=self.minimize,

ax/core/optimization_config.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from __future__ import annotations
1010

1111
from itertools import groupby
12-
from typing import Self
1312

1413
from ax.core.arm import Arm
1514
from ax.core.metric import Metric
@@ -81,7 +80,7 @@ def __init__(
8180
self._outcome_constraints: list[OutcomeConstraint] = constraints
8281
self.pruning_target_parameterization = pruning_target_parameterization
8382

84-
def clone(self) -> Self:
83+
def clone(self) -> "OptimizationConfig":
8584
"""Make a copy of this optimization config."""
8685
return self.clone_with_args()
8786

@@ -91,7 +90,7 @@ def clone_with_args(
9190
outcome_constraints: None | (list[OutcomeConstraint]) = _NO_OUTCOME_CONSTRAINTS,
9291
pruning_target_parameterization: Arm
9392
| None = _NO_PRUNING_TARGET_PARAMETERIZATION,
94-
) -> Self:
93+
) -> "OptimizationConfig":
9594
"""Make a copy of this optimization config."""
9695
objective = self.objective.clone() if objective is None else objective
9796
outcome_constraints = (
@@ -105,7 +104,7 @@ def clone_with_args(
105104
else pruning_target_parameterization
106105
)
107106

108-
return self.__class__(
107+
return OptimizationConfig(
109108
objective=objective,
110109
outcome_constraints=outcome_constraints,
111110
pruning_target_parameterization=pruning_target_parameterization,

ax/core/parameter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from enum import Enum
1515
from logging import Logger
1616
from math import inf
17-
from typing import Any, cast, Self, Union
17+
from typing import Any, cast, Union
1818
from warnings import warn
1919

2020
import numpy as np
@@ -237,7 +237,7 @@ def dependents(self) -> dict[TParamValue, list[str]]:
237237
)
238238

239239
# pyre-fixme[7]: Expected `Parameter` but got implicit return value of `None`.
240-
def clone(self) -> Self:
240+
def clone(self) -> Parameter:
241241
pass
242242

243243
def disable(self, default_value: TParamValue) -> None:

0 commit comments

Comments
 (0)