diff --git a/qiskit/primitives/primitive_job.py b/qiskit/primitives/primitive_job.py index 580ab5930df5..bc316cfde61d 100644 --- a/qiskit/primitives/primitive_job.py +++ b/qiskit/primitives/primitive_job.py @@ -48,10 +48,15 @@ def _submit(self): self._future = executor.submit(self._function, *self._args, **self._kwargs) executor.shutdown(wait=False) - def _prepare_dump(self): - """This method allows PrimitiveJob to be serialized""" + def __getstate__(self): _ = self.result() _ = self.status() + state = self.__dict__.copy() + state["_future"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) self._future = None def result(self) -> ResultT: diff --git a/test/python/primitives/test_primitive_job.py b/test/python/primitives/test_primitive_job.py new file mode 100644 index 000000000000..c76dcc33d87d --- /dev/null +++ b/test/python/primitives/test_primitive_job.py @@ -0,0 +1,51 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2025. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Tests for PrimitiveJob.""" + +import pickle +from test import QiskitTestCase + +import numpy as np +from ddt import data, ddt + +from qiskit import QuantumCircuit +from qiskit.primitives import PrimitiveJob, StatevectorSampler + + +@ddt +class TestPrimitiveJob(QiskitTestCase): + """Tests PrimitiveJob.""" + + @data(1, 2, 3) + def test_serialize(self, size): + """Test serialize.""" + n = 2 + qc = QuantumCircuit(n) + qc.h(range(n)) + qc.measure_all() + sampler = StatevectorSampler() + job = sampler.run([qc] * size) + obj = pickle.dumps(job) + job2 = pickle.loads(obj) + self.assertIsInstance(job2, PrimitiveJob) + self.assertEqual(job.job_id(), job2.job_id()) + self.assertEqual(job.status(), job2.status()) + self.assertEqual(job.metadata, job2.metadata) + result = job.result() + result2 = job2.result() + self.assertEqual(result.metadata, result2.metadata) + self.assertEqual(len(result), len(result2)) + for i in range(len(result)): + self.assertEqual(result[i].metadata, result2[i].metadata) + self.assertEqual(result[i].data.keys(), result2[i].data.keys()) + np.testing.assert_allclose(result[i].join_data().array, result2[i].join_data().array)