|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import copy |
16 | | -import os |
17 | 16 | import unittest |
18 | 17 |
|
19 | 18 | import numpy as np |
@@ -121,17 +120,9 @@ def init_dtype_type(self): |
121 | 120 | self.accumulate = False |
122 | 121 |
|
123 | 122 | def setPlace(self): |
124 | | - self.place = [] |
125 | | - if ( |
126 | | - os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower() |
127 | | - in ['1', 'true', 'on'] |
128 | | - or not paddle.is_compiled_with_cuda() |
129 | | - ): |
130 | | - self.place.append('cpu') |
131 | | - if self.dtype_np is np.float16: |
132 | | - self.place = [] |
133 | | - if paddle.is_compiled_with_cuda(): |
134 | | - self.place.append('gpu') |
| 123 | + self.place = get_places(string_format=True) |
| 124 | + if self.dtype_np is np.float16 and "cpu" in self.place: |
| 125 | + self.place.remove("cpu") |
135 | 126 |
|
136 | 127 | def test_dygraph_forward(self): |
137 | 128 | paddle.disable_static() |
@@ -1028,17 +1019,9 @@ def init_dtype_type(self): |
1028 | 1019 | self.index_type_pd = paddle.int64 |
1029 | 1020 |
|
1030 | 1021 | def setPlace(self): |
1031 | | - self.place = [] |
1032 | | - if ( |
1033 | | - os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower() |
1034 | | - in ['1', 'true', 'on'] |
1035 | | - or not paddle.is_compiled_with_cuda() |
1036 | | - ): |
1037 | | - self.place.append('cpu') |
1038 | | - if self.dtype_np is np.float16: |
1039 | | - self.place = [] |
1040 | | - if paddle.is_compiled_with_cuda(): |
1041 | | - self.place.append('gpu') |
| 1022 | + self.place = get_places(string_format=True) |
| 1023 | + if self.dtype_np is np.float16 and "cpu" in self.place: |
| 1024 | + self.place.remove("cpu") |
1042 | 1025 |
|
1043 | 1026 | def test_dygraph_forward(self): |
1044 | 1027 | paddle.disable_static() |
|
0 commit comments