Skip to content

Commit dd2e192

Browse files
jeongyoonleeclaude
andcommitted
fix: remove sklearn.utils._random import to avoid DEFAULT_SEED signature mismatch
- Copy our_rand_r and RAND_R_MAX implementations locally - Avoids sklearn 1.6+ DEFAULT_SEED const qualifier change - Maintains BSD-3-Clause license compatibility Co-Authored-By: Claude (claude-sonnet-4-5) <noreply@anthropic.com>
1 parent 50c640d commit dd2e192

1 file changed

Lines changed: 31 additions & 1 deletion

File tree

causalml/inference/tree/_tree/_utils.pyx

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,37 @@ import numpy as np
1616
cimport numpy as cnp
1717
cnp.import_array()
1818

19-
from sklearn.utils._random cimport our_rand_r
19+
# Random number generation utilities
20+
# Copied from sklearn.utils._random to avoid DEFAULT_SEED signature mismatch
21+
# Original authors: The scikit-learn developers
22+
# License: BSD-3-Clause
23+
# Copied from sklearn 1.6+ _random.pxd to avoid signature mismatch issues
24+
25+
from ._typedefs cimport uint32_t
26+
27+
cdef const uint32_t DEFAULT_SEED = 1
28+
29+
cdef enum:
30+
# Max value for our rand_r replacement.
31+
# Corresponds to the maximum representable value for
32+
# 32-bit signed integers (i.e. 2^31 - 1).
33+
RAND_R_MAX = 2147483647
34+
35+
# rand_r replacement using a 32bit XorShift generator
36+
# See http://www.jstatsoft.org/v08/i14/paper for details
37+
cdef inline uint32_t our_rand_r(uint32_t* seed) nogil:
38+
"""Generate a pseudo-random np.uint32 from a np.uint32 seed"""
39+
# seed shouldn't ever be 0.
40+
if (seed[0] == 0):
41+
seed[0] = DEFAULT_SEED
42+
43+
seed[0] ^= <uint32_t>(seed[0] << 13)
44+
seed[0] ^= <uint32_t>(seed[0] >> 17)
45+
seed[0] ^= <uint32_t>(seed[0] << 5)
46+
47+
# Use the modulo to ensure we don't return values greater than
48+
# the maximum representable value for signed 32bit integers.
49+
return seed[0] % ((<uint32_t>RAND_R_MAX) + 1)
2050

2151
# =============================================================================
2252
# Helper functions

0 commit comments

Comments
 (0)