Skip to content

Commit 26af0ee

Browse files
authored
Replace some op_test.get_places() - part11 (PaddlePaddle#73843)
1 parent 3e51521 commit 26af0ee

File tree

1 file changed

+6
-23
lines changed

1 file changed

+6
-23
lines changed

test/legacy_test/test_index_put_op.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import copy
16-
import os
1716
import unittest
1817

1918
import numpy as np
@@ -121,17 +120,9 @@ def init_dtype_type(self):
121120
self.accumulate = False
122121

123122
def setPlace(self):
124-
self.place = []
125-
if (
126-
os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
127-
in ['1', 'true', 'on']
128-
or not paddle.is_compiled_with_cuda()
129-
):
130-
self.place.append('cpu')
131-
if self.dtype_np is np.float16:
132-
self.place = []
133-
if paddle.is_compiled_with_cuda():
134-
self.place.append('gpu')
123+
self.place = get_places(string_format=True)
124+
if self.dtype_np is np.float16 and "cpu" in self.place:
125+
self.place.remove("cpu")
135126

136127
def test_dygraph_forward(self):
137128
paddle.disable_static()
@@ -1028,17 +1019,9 @@ def init_dtype_type(self):
10281019
self.index_type_pd = paddle.int64
10291020

10301021
def setPlace(self):
1031-
self.place = []
1032-
if (
1033-
os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
1034-
in ['1', 'true', 'on']
1035-
or not paddle.is_compiled_with_cuda()
1036-
):
1037-
self.place.append('cpu')
1038-
if self.dtype_np is np.float16:
1039-
self.place = []
1040-
if paddle.is_compiled_with_cuda():
1041-
self.place.append('gpu')
1022+
self.place = get_places(string_format=True)
1023+
if self.dtype_np is np.float16 and "cpu" in self.place:
1024+
self.place.remove("cpu")
10421025

10431026
def test_dygraph_forward(self):
10441027
paddle.disable_static()

0 commit comments

Comments
 (0)