Skip to content

Commit d6e116a

Browse files
Cache cl.Context on OpenCL
1 parent 81c33eb commit d6e116a

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

xobjects/context.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -668,14 +668,15 @@ def get_context_from_string(ctxstr):
668668

669669
if ctxstr is None:
670670
return xo.ContextCpu()
671+
672+
ll = ctxstr.split(":")
673+
if len(ll) <= 1:
674+
ctxtype = ll[0]
675+
option = []
671676
else:
672-
ll = ctxstr.split(":")
673-
if len(ll) <= 1:
674-
ctxtype = ll[0]
675-
option = []
676-
else:
677-
ctxtype, options = ctxstr.split(":")
678-
option = options.split(",")
677+
ctxtype, options = ctxstr.split(":")
678+
option = options.split(",")
679+
679680
if ctxtype == "ContextCpu":
680681
if len(option) == 0:
681682
return xo.ContextCpu()

xobjects/context_pyopencl.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ def _build_view(cls, a):
8181

8282

8383
class ContextPyopencl(XContext):
84+
context_cache = {}
85+
8486
@property
8587
def nplike_array_type(self):
8688
return cla.Array
@@ -128,20 +130,23 @@ def __init__(
128130
super().__init__()
129131

130132
# TODO assume one device only
131-
if device is None:
133+
if device in self.context_cache:
134+
self.platform, self.device, self.context = self.context_cache[device]
135+
elif device is None:
132136
self.context = cl.create_some_context(interactive=False)
133137
self.device = self.context.devices[0]
134138
self.platform = self.device.platform
135139
else:
136140
if isinstance(device, str):
137-
platform, device = map(int, device.split("."))
141+
platform, device_ = map(int, device.split("."))
138142
self.platform = cl.get_platforms()[platform]
139-
self.device = self.platform.get_devices()[device]
143+
self.device = self.platform.get_devices()[device_]
140144
else:
141145
self.device = device
142146
self.platform = device.platform
143147

144148
self.context = cl.Context([self.device])
149+
self.context_cache[device] = self.platform, self.device, self.context
145150

146151
self.queue = cl.CommandQueue(self.context)
147152

0 commit comments

Comments
 (0)