Skip to content

Commit 5ae5ceb

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Import OrderedDict from collections
Summary: `typing.OrderedDict` is a deprecated alias of `collections.OrderedDict`: https://docs.python.org/3/library/typing.html#typing.OrderedDict Reviewed By: Balandat Differential Revision: D56795174 fbshipit-source-id: 5e2f64271e01ebd76b1578a22a40cde18b4582e1
1 parent 34df921 commit 5ae5ceb

File tree

5 files changed

+10
-9
lines changed

5 files changed

+10
-9
lines changed

ax/models/torch/botorch_modular/surrogate.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
from __future__ import annotations
1010

1111
import inspect
12+
from collections import OrderedDict
1213
from collections.abc import Sequence
1314
from copy import deepcopy
1415
from logging import Logger
15-
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Type, Union
16+
from typing import Any, Dict, List, Optional, Tuple, Type, Union
1617

1718
import torch
1819
from ax.core.search_space import SearchSpaceDigest

ax/models/torch/botorch_modular/utils.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
# pyre-strict
88

99
import warnings
10+
from collections import OrderedDict
1011
from collections.abc import Sequence
1112
from logging import Logger
12-
from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Type
13+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
1314

1415
import torch
1516
from ax.core.search_space import SearchSpaceDigest
@@ -463,4 +464,4 @@ def subset_state_dict(
463464
for k, v in state_dict.items()
464465
if k.startswith(expected_substring)
465466
]
466-
return OrderedDict(new_items) # pyre-ignore [29]: T168826187
467+
return OrderedDict(new_items)

ax/models/torch/tests/test_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
# pyre-strict
88

99
import dataclasses
10+
from collections import OrderedDict
1011
from contextlib import ExitStack
1112
from copy import deepcopy
12-
from typing import Dict, OrderedDict, Type
13+
from typing import Dict, Type
1314
from unittest import mock
1415
from unittest.mock import Mock
1516

@@ -397,7 +398,6 @@ def test_cross_validate(self, mock_fit: Mock) -> None:
397398

398399
old_surrogate = self.model.surrogates[Keys.ONLY_SURROGATE]
399400
old_surrogate._model = mock.MagicMock()
400-
# pyre-ignore [29]: T168826187
401401
old_surrogate._model.state_dict.return_value = OrderedDict({"key": "val"})
402402

403403
for refit_on_cv, warm_start_refit in [

ax/models/torch/tests/test_surrogate.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
import dataclasses
1010
import math
11-
from typing import Any, Dict, OrderedDict, Tuple, Type
11+
from collections import OrderedDict
12+
from typing import Any, Dict, Tuple, Type
1213
from unittest.mock import MagicMock, Mock, patch
1314

1415
import numpy as np
@@ -954,7 +955,6 @@ def test_fit(
954955

955956
# Should `load_state_dict` when `state_dict` is not `None`
956957
# and `refit` is `False`.
957-
# pyre-ignore [29]: T168826187
958958
state_dict = OrderedDict({"state_attribute": torch.ones(2)})
959959
surrogate._submodels = {} # Prevent re-use of fitted model.
960960
surrogate.fit(

ax/models/torch/tests/test_utils.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99
import warnings
10-
from typing import OrderedDict
10+
from collections import OrderedDict
1111

1212
import numpy as np
1313
import torch
@@ -615,7 +615,6 @@ def test_subset_state_dict(self) -> None:
615615
m0 = SingleTaskGP(train_X=torch.rand(5, 2), train_Y=torch.rand(5, 1))
616616
m1 = SingleTaskGP(train_X=torch.rand(5, 2), train_Y=torch.rand(5, 1))
617617
model_list = ModelListGP(m0, m1)
618-
# pyre-ignore [6]: T168826187
619618
model_list_state_dict = checked_cast(OrderedDict, model_list.state_dict())
620619
# Subset the model dict from model list and check that it is correct.
621620
m0_state_dict = model_list.models[0].state_dict()

0 commit comments

Comments
 (0)