Skip to content

Commit cb3aa17

Browse files
authored
Merge pull request #12 from Ericgig/add_context
Add a contextmanager and reverse default option
2 parents f051e51 + 1badaad commit cb3aa17

3 files changed

Lines changed: 117 additions & 17 deletions

File tree

README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ see [nvidia's documentation](https://docs.nvidia.com/cuda/cuquantum/latest/getti
3030

3131
## Usage
3232

33-
In simple case, simply calling `set_as_default` before a qutip script should be sufficient to use the backend common solver:
33+
In simple case, simply calling `set_as_default` before a qutip script should be sufficient to use the backend in common solver:
3434

3535
```
3636
import qutip_cuquantum
@@ -39,5 +39,12 @@ from cuquantum.densitymat import WorkStream
3939
qutip_cuquantum.set_as_default(WorkStream())
4040
```
4141

42+
It can also be used as a context:
43+
44+
```
45+
with CuQuantumBackend(ctx):
46+
...
47+
```
48+
4249
qutip-cuquantum work well to speed-up large simulation using `mesolve` or `sesolve`.
4350
However this backend is not compatible with advanced qutip solvers (brmesolve, HEOM) and other various feature.

doc/source/solver.rst

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,23 @@ This is done by calling the ``set_as_default`` function and providing it with a
3232
The ``set_as_default`` function changes several QuTiP defaults to route computations through the cuQuantum library.
3333
This includes setting the default data format for quantum objects (``Qobj``) to ``CuOperator`` and configuring the solvers to use GPU-compatible integrators.
3434

35-
.. warning::
36-
This operation is **not reversible** within the same Python session.
37-
Once the cuQuantum backend is set, all subsequent compatible operations will be dispatched to the GPU.
35+
This operation can be reversed with:
36+
37+
.. code-block:: python
38+
39+
qutip_cuquantum.set_as_default(reverse=True)
40+
41+
42+
The backend can also be enabled with a context:
43+
44+
.. code-block:: python
45+
46+
with CuQuantumBackend(ctx):
47+
...
48+
49+
However be careful when mixing core Qutip object and Qutip-cuQuantum's one.
50+
Qutip's Qobj do not keep all the internal structure needed for cuQuantum's optimizations.
51+
Qutip-cuQuantum's states can be distributed in multiple processes and unusable for many qutip's core features.
3852

3953
==================================
4054
Usage with Solvers

src/qutip_cuquantum/__init__.py

Lines changed: 92 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from qutip.core.options import QutipOptions
2828
from .operator import CuOperator
2929
from .state import CuState
30+
import numpy
3031

3132

3233
# TODO: The split per density is not great
@@ -62,6 +63,9 @@ def CuState_from_CuPyDense(mat):
6263

6364
from .qobjevo import CuQobjEvo
6465
from .ode import Result, CuMCIntegrator
66+
from qutip import settings
67+
from qutip.solver import SESolver, MESolver, MCSolver, Result as BaseResult
68+
from qutip.solver.mcsolve import MCIntegrator
6569

6670

6771
class cuDensityOption(QutipOptions):
@@ -74,7 +78,7 @@ class cuDensityOption(QutipOptions):
7478
cuDensityOption_instance._set_as_global_default()
7579

7680

77-
def set_as_default(ctx: cuquantum.densitymat.WorkStream):
81+
def set_as_default(ctx: cuquantum.densitymat.WorkStream=None, reverse=False):
7882
"""
7983
Update qutip's default to use cuQuantum as a backend.
8084
@@ -83,22 +87,97 @@ def set_as_default(ctx: cuquantum.densitymat.WorkStream):
8387
ctx: WorkStream
8488
A WorkStream instance from cuquantum.density.
8589
It can be set with mpi support for multi-gpu simulations.
90+
Can be ignored when ``reverse=True``.
91+
92+
reverse: bool, default: False
93+
Undo the change of default backend to qutip core defaults.
8694
"""
87-
qutip.settings.cuDensity["ctx"] = ctx
88-
qutip.settings.core["default_dtype"] = "cuDensity"
89-
qutip.settings.core['numpy_backend'] = cupy
95+
if not reverse:
96+
settings.cuDensity["ctx"] = ctx
97+
settings.core["default_dtype"] = "cuDensity"
98+
settings.core['numpy_backend'] = cupy
99+
100+
if True: # if mpi, how to check from ctx?
101+
settings.core["auto_real_casting"] = False
102+
103+
SESolver.solver_options['method'] = "CuVern7"
104+
MESolver.solver_options['method'] = "CuVern7"
105+
MCSolver.solver_options['method'] = "CuVern7"
106+
107+
SESolver._resultclass = Result
108+
MESolver._resultclass = Result
109+
MCSolver._trajectory_resultclass = Result
110+
MCSolver._mc_integrator_class = CuMCIntegrator
111+
112+
else:
113+
settings.core["default_dtype"] = "core"
114+
settings.core['numpy_backend'] = numpy
115+
settings.core["auto_real_casting"] = True
116+
117+
SESolver.solver_options['method'] = "adams"
118+
MESolver.solver_options['method'] = "adams"
119+
MCSolver.solver_options['method'] = "vern7"
120+
121+
SESolver._resultclass = BaseResult
122+
MESolver._resultclass = BaseResult
123+
MCSolver._trajectory_resultclass = BaseResult
124+
MCSolver._mc_integrator_class = MCIntegrator
90125

91-
if True: # if mpi, how to check from ctx?
92-
qutip.settings.core["auto_real_casting"] = False
93126

94-
qutip.SESolver.solver_options['method'] = "CuVern7"
95-
qutip.MESolver.solver_options['method'] = "CuVern7"
96-
qutip.MCSolver.solver_options['method'] = "CuVern7"
97127

98-
qutip.SESolver._resultclass = Result
99-
qutip.MESolver._resultclass = Result
100-
qutip.MCSolver._trajectory_resultclass = Result
101-
qutip.MCSolver._mc_integrator_class = CuMCIntegrator
128+
class CuQuantumBackend:
129+
"""
130+
A context manager class to temporarily set cuQuantum as the default
131+
backend.
132+
133+
Parameters
134+
----------
135+
ctx : cuquantum.densitymat.WorkStream
136+
A WorkStream instance from cuquantum.density.
137+
It can be set with mpi support for multi-gpu simulations.
138+
"""
139+
def __init__(self, ctx):
140+
self.ctx = ctx
141+
self.previous_values = {}
142+
143+
def __enter__(self):
144+
settings.cuDensity["ctx"] = self.ctx
145+
self.previous_values["default_dtype"] = qutip.settings.core["default_dtype"]
146+
settings.core["default_dtype"] = "cuDensity"
147+
self.previous_values["numpy_backend"] = qutip.settings.core["numpy_backend"]
148+
settings.core['numpy_backend'] = cupy
149+
150+
self.previous_values["auto_real"] = settings.core["auto_real_casting"]
151+
if True: # if mpi, how to check from ctx?
152+
settings.core["auto_real_casting"] = False
153+
154+
self.previous_values["SESolverM"] = SESolver.solver_options['method']
155+
self.previous_values["MESolverM"] = MESolver.solver_options['method']
156+
self.previous_values["MCSolverM"] = MCSolver.solver_options['method']
157+
SESolver.solver_options['method'] = "CuVern7"
158+
MESolver.solver_options['method'] = "CuVern7"
159+
MCSolver.solver_options['method'] = "CuVern7"
160+
161+
self.previous_values["SESolverR"] = SESolver._resultclass
162+
self.previous_values["MESolverR"] = MESolver._resultclass
163+
self.previous_values["MCSolverR"] = MCSolver._trajectory_resultclass
164+
self.previous_values["MCSolverI"] = MCSolver._mc_integrator_class
165+
SESolver._resultclass = Result
166+
MESolver._resultclass = Result
167+
MCSolver._trajectory_resultclass = Result
168+
MCSolver._mc_integrator_class = CuMCIntegrator
169+
170+
def __exit__(self, exc_type, exc_value, traceback):
171+
settings.core["default_dtype"] = self.previous_values["default_dtype"]
172+
settings.core['numpy_backend'] = self.previous_values["numpy_backend"]
173+
settings.core["auto_real_casting"] = self.previous_values["auto_real"]
174+
SESolver.solver_options['method'] = self.previous_values["SESolverM"]
175+
MESolver.solver_options['method'] = self.previous_values["MESolverM"]
176+
MCSolver.solver_options['method'] = self.previous_values["MCSolverM"]
177+
SESolver._resultclass = self.previous_values["SESolverR"]
178+
MESolver._resultclass = self.previous_values["MESolverR"]
179+
MCSolver._trajectory_resultclass = self.previous_values["MCSolverR"]
180+
MCSolver._mc_integrator_class = self.previous_values["MCSolverI"]
102181

103182

104183

0 commit comments

Comments
 (0)