Skip to content

Commit 6e20bd0

Browse files
authored
Merge pull request #318 from bashtage/reenable-pickle
ENH: Allow pickle with NumPy Generator
2 parents 2aa1ddb + 8846a2c commit 6e20bd0

37 files changed

+168
-112
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,10 @@ The RNGs include:
9898
## Status
9999

100100
* Builds and passes all tests on:
101-
* Linux 32/64 bit, Python 2.7, 3.5, 3.6, 3.7
102-
* Linux (ARM/ARM64), Python 3.7
103-
* OSX 64-bit, Python 2.7, 3.5, 3.6, 3.7
104-
* Windows 32/64 bit, Python 2.7, 3.5, 3.6, 3.7
101+
* Linux 32/64 bit, Python 3.7, 3.8, 3.9, 3.10
102+
* Linux (ARM/ARM64), Python 3.8
103+
* OSX 64-bit, Python 3.9
104+
* Windows 32/64 bit, Python 3.7, 3.8, 3.9, 3.10
105105
* FreeBSD 64-bit
106106

107107
## Version

ci/azure/azure_template_posix.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@ jobs:
2626
python.version: '3.8'
2727
coverage: true
2828
NUMPY: 1.17.0
29+
python37_latest:
30+
python.version: '3.7'
31+
python38_latest:
32+
python.version: '3.8'
2933
python39_latest:
3034
python.version: '3.9'
3135
python310_latest:
3236
python.version: '3.10'
33-
python36_latest:
34-
python.version: '3.6'
3537
python38_mid_conda:
3638
python.version: '3.8'
3739
use.conda: true

ci/azure/azure_template_windows.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@ jobs:
1616
vmImage: ${{ parameters.vmImage }}
1717
strategy:
1818
matrix:
19+
python37_win_latest:
20+
python.version: '3.7'
1921
python38_win_latest:
2022
python.version: '3.8'
2123
python39_win_latest:
2224
python.version: '3.9'
25+
python310_win_latest:
26+
python.version: '3.10'
2327
maxParallel: 10
2428

2529
steps:

doc/source/change-log.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,16 @@ Change Log
1313
You should be using :class:`numpy.random.Generator` or
1414
:class:`numpy.random.RandomState` which are maintained.
1515

16+
v1.23.1
17+
=======
18+
- Registered the bit generators included in ``randomgen`` with NumPy
19+
so that NumPy :class:`~numpy.random.Generator` instances can be pickled
20+
and unpickled when using a ``randomstate`` bit generator.
21+
- Changed the canonical name of the bit generators to be their fully qualified
22+
name. For example, :class:`~randomgen.pcg64.PCG64` is not named ``"randomgen.pcg64.PCG64"``
23+
instead of ``"PCG64"``. This was done to avoid ambiguity with NumPy's supplied
24+
bit generators with the same name.
25+
1626
v1.23.0
1727
=======
1828
- Removed ``Generator`` and ``RandomState``.

randomgen/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
from typing import List, Union
44

5+
from randomgen._register import BitGenerators
56
from randomgen.aes import AESCounter
67
from randomgen.chacha import ChaCha
78
from randomgen.dsfmt import DSFMT
@@ -37,6 +38,7 @@
3738

3839
__all__ = [
3940
"AESCounter",
41+
"BitGenerators",
4042
"ChaCha",
4143
"DSFMT",
4244
"EFIIX64",

randomgen/_pickle.py

Lines changed: 5 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66
from randomgen.common import BitGenerator
77
from randomgen.dsfmt import DSFMT
88
from randomgen.efiix64 import EFIIX64
9-
from randomgen.generator import ExtendedGenerator, Generator
9+
from randomgen.generator import ExtendedGenerator
1010
from randomgen.hc128 import HC128
1111
from randomgen.jsf import JSF
1212
from randomgen.lxm import LXM
1313
from randomgen.mt64 import MT64
1414
from randomgen.mt19937 import MT19937
15-
from randomgen.mtrand import RandomState
1615
from randomgen.pcg32 import PCG32
1716
from randomgen.pcg64 import PCG64, PCG64DXSM, LCG128Mix
1817
from randomgen.philox import Philox
@@ -54,6 +53,10 @@
5453
"RDRAND": RDRAND,
5554
}
5655

56+
# Assign the fully qualified name for future proofness
57+
for value in list(BitGenerators.values()):
58+
BitGenerators[f"{value.__module__}.{value.__name__}"] = value
59+
5760

5861
def _get_bitgenerator(bit_generator_name: str) -> Type[BitGenerator]:
5962
"""
@@ -75,29 +78,6 @@ def _decode(name: Union[str, bytes]) -> str:
7578
return name.decode("ascii")
7679

7780

78-
def __generator_ctor(bit_generator_name: Union[bytes, str] = "MT19937") -> Generator:
79-
"""
80-
Pickling helper function that returns a Generator object
81-
82-
Parameters
83-
----------
84-
bit_generator_name: str
85-
String containing the core BitGenerator
86-
87-
Returns
88-
-------
89-
rg: Generator
90-
Generator using the named core BitGenerator
91-
"""
92-
bit_generator_name = _decode(bit_generator_name)
93-
assert isinstance(bit_generator_name, str)
94-
bit_generator = _get_bitgenerator(bit_generator_name)
95-
with warnings.catch_warnings():
96-
warnings.filterwarnings("ignore", category=FutureWarning)
97-
bit_gen = bit_generator()
98-
return Generator(bit_gen)
99-
100-
10181
def __extended_generator_ctor(
10282
bit_generator_name: Union[str, bytes] = "MT19937"
10383
) -> ExtendedGenerator:
@@ -146,28 +126,3 @@ def __bit_generator_ctor(
146126
warnings.filterwarnings("ignore", category=FutureWarning)
147127
bit_gen = bit_generator()
148128
return bit_gen
149-
150-
151-
def __randomstate_ctor(
152-
bit_generator_name: Union[str, bytes] = "MT19937"
153-
) -> RandomState:
154-
"""
155-
Pickling helper function that returns a legacy RandomState-like object
156-
157-
Parameters
158-
----------
159-
bit_generator_name: str
160-
String containing the core BitGenerator
161-
162-
Returns
163-
-------
164-
rs: RandomState
165-
Legacy RandomState using the named core BitGenerator
166-
"""
167-
bit_generator_name = _decode(bit_generator_name)
168-
assert isinstance(bit_generator_name, str)
169-
bit_generator = _get_bitgenerator(bit_generator_name)
170-
with warnings.catch_warnings():
171-
warnings.filterwarnings("ignore", category=FutureWarning)
172-
bit_gen = bit_generator()
173-
return RandomState(bit_gen)

randomgen/_register.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from numpy.random._pickle import BitGenerators
2+
3+
from randomgen.aes import AESCounter
4+
from randomgen.chacha import ChaCha
5+
from randomgen.dsfmt import DSFMT
6+
from randomgen.efiix64 import EFIIX64
7+
from randomgen.hc128 import HC128
8+
from randomgen.jsf import JSF
9+
from randomgen.lxm import LXM
10+
from randomgen.mt64 import MT64
11+
from randomgen.mt19937 import MT19937
12+
from randomgen.pcg32 import PCG32
13+
from randomgen.pcg64 import PCG64, PCG64DXSM, LCG128Mix
14+
from randomgen.philox import Philox
15+
from randomgen.rdrand import RDRAND
16+
from randomgen.romu import Romu
17+
from randomgen.sfc import SFC64
18+
from randomgen.sfmt import SFMT
19+
from randomgen.speck128 import SPECK128
20+
from randomgen.threefry import ThreeFry
21+
from randomgen.wrapper import UserBitGenerator
22+
from randomgen.xoroshiro128 import Xoroshiro128
23+
from randomgen.xorshift1024 import Xorshift1024
24+
from randomgen.xoshiro256 import Xoshiro256
25+
from randomgen.xoshiro512 import Xoshiro512
26+
27+
bit_generators = [
28+
AESCounter,
29+
ChaCha,
30+
DSFMT,
31+
EFIIX64,
32+
HC128,
33+
JSF,
34+
LXM,
35+
MT19937,
36+
MT64,
37+
PCG32,
38+
PCG64,
39+
PCG64DXSM,
40+
LCG128Mix,
41+
Philox,
42+
RDRAND,
43+
Romu,
44+
SFC64,
45+
SFMT,
46+
SPECK128,
47+
ThreeFry,
48+
UserBitGenerator,
49+
Xoroshiro128,
50+
Xorshift1024,
51+
Xoshiro256,
52+
Xoshiro512,
53+
]
54+
55+
for bitgen in bit_generators:
56+
key = f"{bitgen.__name__}"
57+
if key not in BitGenerators:
58+
BitGenerators[key] = bitgen
59+
full_key = f"{bitgen.__module__}.{bitgen.__name__}"
60+
BitGenerators[full_key] = bitgen
61+
62+
__all__ = ["BitGenerators"]

randomgen/aes.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ cdef class AESCounter(BitGenerator):
277277
for i in range(16 * 4):
278278
state[i] = self.rng_state.state[i]
279279
offset = self.rng_state.offset
280-
return {"bit_generator": type(self).__name__,
280+
return {"bit_generator": fully_qualified_name(self),
281281
"s": {"state": state, "seed": seed, "counter": counter,
282282
"offset": offset},
283283
"has_uint32": self.rng_state.has_uint32,
@@ -290,7 +290,7 @@ cdef class AESCounter(BitGenerator):
290290
if not isinstance(value, dict):
291291
raise TypeError("state must be a dict")
292292
bitgen = value.get("bit_generator", "")
293-
if bitgen != type(self).__name__:
293+
if bitgen not in (type(self).__name__, fully_qualified_name(self)):
294294
raise ValueError("state must be for a {0} "
295295
"PRNG".format(type(self).__name__))
296296
state =value["s"]["state"]

randomgen/chacha.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ cdef class ChaCha(BitGenerator):
279279
for i in range(2):
280280
ctr[i] = self.rng_state.ctr[i]
281281

282-
return {"bit_generator": type(self).__name__,
282+
return {"bit_generator": fully_qualified_name(self),
283283
"state": {"block": block, "keysetup": keysetup, "ctr": ctr,
284284
"rounds": self.rng_state.rounds}}
285285

@@ -288,7 +288,7 @@ cdef class ChaCha(BitGenerator):
288288
if not isinstance(value, dict):
289289
raise TypeError("state must be a dict")
290290
bitgen = value.get("bit_generator", "")
291-
if bitgen != type(self).__name__:
291+
if bitgen not in (type(self).__name__, fully_qualified_name(self)):
292292
raise ValueError("state must be for a {0} "
293293
"PRNG".format(type(self).__name__))
294294

randomgen/common.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,6 @@ cdef inline void compute_complex(double *rv_r, double *rv_i, double loc_r,
148148

149149
rv_i[0] = loc_i + scale_i * (rho * rv_r[0] + scale_c * rv_i[0])
150150
rv_r[0] = loc_r + scale_r * rv_r[0]
151+
152+
153+
cdef object fully_qualified_name(instance)

0 commit comments

Comments
 (0)