Skip to content

Commit 3f834d3

Browse files
committed
adds test_opencl_offloading
1 parent 5caea1b commit 3f834d3

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed

Diff for: test/unit/test_opencl_offloading.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import pytest
2+
pytest.importorskip("pyopencl")
3+
4+
import sys
5+
import petsc4py
6+
petsc4py.init(sys.argv
7+
+ "-viennacl_backend opencl".split()
8+
+ "-viennacl_opencl_device_type cpu".split())
9+
from pyop2 import op2
10+
import pyopencl.array as cla
11+
import numpy as np
12+
13+
14+
def pytest_generate_tests(metafunc):
15+
if "backend" in metafunc.fixturenames:
16+
from pyop2.backends.opencl import opencl_backend
17+
metafunc.parametrize("backend", [opencl_backend])
18+
19+
20+
def test_new_backend_raises_not_implemented_error():
21+
from pyop2.backends import AbstractComputeBackend
22+
unimplemented_backend = AbstractComputeBackend()
23+
24+
attrs = ["GlobalKernel", "Parloop", "Set", "ExtrudedSet", "MixedSet",
25+
"Subset", "DataSet", "MixedDataSet", "Map", "MixedMap", "Dat",
26+
"MixedDat", "DatView", "Mat", "Global", "GlobalDataSet",
27+
"PETScVecType"]
28+
29+
for attr in attrs:
30+
with pytest.raises(NotImplementedError):
31+
getattr(unimplemented_backend, attr)
32+
33+
34+
def test_dat_with_petscvec_representation(backend):
35+
op2.set_offloading_backend(backend)
36+
37+
nelems = 9
38+
data = np.random.rand(nelems)
39+
set_ = op2.compute_backend.Set(nelems)
40+
dset = op2.compute_backend.DataSet(set_, 1)
41+
dat = op2.compute_backend.Dat(dset, data.copy())
42+
43+
assert isinstance(dat.data_ro, np.ndarray)
44+
dat.data[:] *= 3
45+
46+
with op2.offloading():
47+
assert isinstance(dat.data_ro, cla.Array)
48+
dat.data[:] *= 2
49+
50+
assert isinstance(dat.data_ro, np.ndarray)
51+
np.testing.assert_allclose(dat.data_ro, 6*data)
52+
53+
54+
def test_dat_not_as_petscvec(backend):
55+
op2.set_offloading_backend(backend)
56+
57+
nelems = 9
58+
data = np.random.randint(low=-10, high=10,
59+
size=nelems,
60+
dtype=np.int64)
61+
set_ = op2.compute_backend.Set(nelems)
62+
dset = op2.compute_backend.DataSet(set_, 1)
63+
dat = op2.compute_backend.Dat(dset, data.copy())
64+
65+
assert isinstance(dat.data_ro, np.ndarray)
66+
dat.data[:] *= 3
67+
68+
with op2.offloading():
69+
assert isinstance(dat.data_ro, cla.Array)
70+
dat.data[:] *= 2
71+
72+
assert isinstance(dat.data_ro, np.ndarray)
73+
np.testing.assert_allclose(dat.data_ro, 6*data)

0 commit comments

Comments
 (0)