Skip to content

Commit bd34a31

Browse files
authored
feat: add subinterpreters support (#245)
1 parent 9e33e16 commit bd34a31

File tree

21 files changed

+734
-135
lines changed

21 files changed

+734
-135
lines changed

.github/workflows/build.yml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ on:
66
- main # allow to trigger the workflow with tag push event
77
pull_request:
88
types:
9+
- labeled
10+
- unlabeled
911
- opened
1012
- synchronize
1113
- reopened
@@ -33,6 +35,11 @@ on:
3335
- build-only
3436
- build-and-publish
3537
required: true
38+
nightly-pybind11:
39+
description: "Use nightly pybind11"
40+
type: boolean
41+
required: false
42+
default: false
3643

3744
permissions:
3845
contents: read
@@ -85,6 +92,20 @@ jobs:
8592
- name: Print version
8693
run: python setup.py --version
8794

95+
- name: Use nightly pybind11
96+
shell: bash
97+
if: |
98+
(github.event_name == 'workflow_dispatch' && github.event.inputs.nightly-pybind11 == 'true') ||
99+
(
100+
github.event_name == 'pull_request' &&
101+
contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11')
102+
)
103+
run: |
104+
python -m pip install --force-reinstall "$(python .github/workflows/set_setup_requires.py)"
105+
echo "::group::pyproject.toml"
106+
cat pyproject.toml
107+
echo "::endgroup::"
108+
88109
- name: Install dependencies
89110
run: python -m pip install --upgrade pip setuptools wheel build
90111

@@ -246,6 +267,20 @@ jobs:
246267
- name: Print version
247268
run: python setup.py --version
248269

270+
- name: Use nightly pybind11
271+
shell: bash
272+
if: |
273+
(github.event_name == 'workflow_dispatch' && github.event.inputs.nightly-pybind11 == 'true') ||
274+
(
275+
github.event_name == 'pull_request' &&
276+
contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11')
277+
)
278+
run: |
279+
python -m pip install --force-reinstall "$(python .github/workflows/set_setup_requires.py)"
280+
echo "::group::pyproject.toml"
281+
cat pyproject.toml
282+
echo "::endgroup::"
283+
249284
- name: Set CIBW_BUILD
250285
shell: bash
251286
run: |

.github/workflows/lint.yml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ on:
1313
- reopened
1414
# Allow to trigger the workflow manually
1515
workflow_dispatch:
16+
inputs:
17+
nightly-pybind11:
18+
description: "Use nightly pybind11"
19+
type: boolean
20+
required: false
21+
default: false
1622

1723
permissions:
1824
contents: read
@@ -72,8 +78,11 @@ jobs:
7278
- name: Install nightly pybind11
7379
shell: bash
7480
if: |
75-
github.event_name == 'pull_request' &&
76-
contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11')
81+
(github.event_name == 'workflow_dispatch' && github.event.inputs.nightly-pybind11 == 'true') ||
82+
(
83+
github.event_name == 'pull_request' &&
84+
contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11')
85+
)
7786
run: |
7887
python -m pip install --force-reinstall "$(python .github/workflows/set_setup_requires.py)"
7988
echo "::group::pyproject.toml"

.github/workflows/tests-with-pydebug.yml

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ on:
44
push:
55
branches:
66
- main
7+
schedule:
8+
# Run at 12:00 Asia/Shanghai (04:00 UTC) every three days with nightly pybind11
9+
- cron: "0 4 */3 * *"
710
pull_request:
811
types:
912
- labeled
@@ -24,6 +27,12 @@ on:
2427
- .github/workflows/tests-with-pydebug.yml
2528
# Allow to trigger the workflow manually
2629
workflow_dispatch:
30+
inputs:
31+
nightly-pybind11:
32+
description: "Use nightly pybind11"
33+
type: boolean
34+
required: false
35+
default: false
2736

2837
permissions:
2938
contents: read
@@ -41,7 +50,6 @@ env:
4150
PYTHON: "python" # to be updated
4251
PYTHON_TAG: "py3" # to be updated
4352
PYTHON_VERSION: "3" # to be updated
44-
pybind11_VERSION: "stable" # to be updated
4553
PYENV_ROOT: "~/.pyenv" # to be updated
4654
VENV_BIN_NAME: "bin" # to be updated
4755
COLUMNS: "100"
@@ -312,11 +320,14 @@ jobs:
312320
- name: Use nightly pybind11
313321
shell: bash
314322
if: |
315-
github.event_name == 'pull_request' &&
316-
contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11')
323+
github.event_name == 'schedule' ||
324+
(github.event_name == 'workflow_dispatch' && github.event.inputs.nightly-pybind11 == 'true') ||
325+
(
326+
github.event_name == 'pull_request' &&
327+
contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11')
328+
)
317329
run: |
318330
source "venv/${VENV_BIN_NAME}/activate"
319-
${{ env.PYTHON }} .github/workflows/set_setup_requires.py
320331
${{ env.PYTHON }} -m pip install --force-reinstall "$(${{ env.PYTHON }} .github/workflows/set_setup_requires.py)"
321332
echo "::group::pyproject.toml"
322333
cat pyproject.toml
@@ -348,7 +359,13 @@ jobs:
348359
"--cov-report=xml:coverage-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml"
349360
"--junit-xml=junit-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml"
350361
)
351-
make test PYTESTOPTS="${PYTESTOPTS[*]}"
362+
363+
if ${{ env.PYTHON }} -c 'import sys, optree; sys.exit(not optree._C.OPTREE_HAS_SUBINTERPRETER_SUPPORT)'; then
364+
make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'concurrent' --no-cov"
365+
make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'not subinterpreter'"
366+
else
367+
make test PYTESTOPTS="${PYTESTOPTS[*]}"
368+
fi
352369
353370
CORE_DUMP_FILES="$(
354371
find . -type d -path "./venv" -prune \

.github/workflows/tests.yml

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ on:
2525
- .github/workflows/tests.yml
2626
# Allow to trigger the workflow manually
2727
workflow_dispatch:
28+
inputs:
29+
nightly-pybind11:
30+
description: "Use nightly pybind11"
31+
type: boolean
32+
required: false
33+
default: false
2834

2935
permissions:
3036
contents: read
@@ -42,7 +48,6 @@ env:
4248
PYTHONUNBUFFERED: "1"
4349
PYTHON: "python" # to be updated
4450
PYTHON_TAG: "py3" # to be updated
45-
pybind11_VERSION: "stable" # to be updated
4651
VENV_BIN_NAME: "bin" # to be updated
4752
COLUMNS: "100"
4853
FORCE_COLOR: "1"
@@ -186,10 +191,12 @@ jobs:
186191
- name: Use nightly pybind11
187192
shell: bash
188193
if: |
189-
github.event_name == 'pull_request' &&
190-
contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11')
194+
(github.event_name == 'workflow_dispatch' && github.event.inputs.nightly-pybind11 == 'true') ||
195+
(
196+
github.event_name == 'pull_request' &&
197+
contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11')
198+
)
191199
run: |
192-
${{ env.PYTHON }} .github/workflows/set_setup_requires.py
193200
${{ env.PYTHON }} -m pip install --force-reinstall "$(${{ env.PYTHON }} .github/workflows/set_setup_requires.py)"
194201
echo "::group::pyproject.toml"
195202
cat pyproject.toml

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ repos:
3838
hooks:
3939
- id: cpplint
4040
- repo: https://github.com/astral-sh/ruff-pre-commit
41-
rev: v0.14.14
41+
rev: v0.15.2
4242
hooks:
4343
- id: ruff-check
4444
args: [--fix, --exit-non-zero-on-fix]

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313

1414
### Added
1515

16-
-
16+
- Add subinterpreters support for Python 3.14+ by [@XuehaiPan](https://github.com/XuehaiPan) in [#245](https://github.com/metaopt/optree/pull/245).
1717

1818
### Changed
1919

include/optree/pymacros.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ limitations under the License.
1717

1818
#pragma once
1919

20+
#include <stdexcept> // std::runtime_error
21+
2022
#include <Python.h>
2123

2224
#include <pybind11/pybind11.h>
@@ -32,6 +34,15 @@ limitations under the License.
3234
// NOLINTNEXTLINE[bugprone-macro-parentheses]
3335
#define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0))
3436

37+
#if !defined(PYPY_VERSION) && (PY_VERSION_HEX >= 0x030E0000 /* Python 3.14 */) && \
38+
(PYBIND11_VERSION_HEX >= 0x030002F0 /* pybind11 3.0.2 */) && \
39+
(defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \
40+
NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT))
41+
# define OPTREE_HAS_SUBINTERPRETER_SUPPORT 1
42+
#else
43+
# undef OPTREE_HAS_SUBINTERPRETER_SUPPORT
44+
#endif
45+
3546
namespace py = pybind11;
3647

3748
#if !defined(Py_ALWAYS_INLINE)
@@ -59,3 +70,50 @@ inline constexpr Py_ALWAYS_INLINE bool Py_IsConstant(PyObject *x) noexcept {
5970
return Py_IsNone(x) || Py_IsTrue(x) || Py_IsFalse(x);
6071
}
6172
#define Py_IsConstant(x) Py_IsConstant(x)
73+
74+
using interpid_t = decltype(PyInterpreterState_GetID(nullptr));
75+
76+
#if defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \
77+
NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT)
78+
79+
[[nodiscard]] inline bool IsCurrentPyInterpreterMain() {
80+
return PyInterpreterState_Get() == PyInterpreterState_Main();
81+
}
82+
83+
[[nodiscard]] inline interpid_t GetCurrentPyInterpreterID() {
84+
PyInterpreterState *interp = PyInterpreterState_Get();
85+
if (PyErr_Occurred() != nullptr) [[unlikely]] {
86+
throw py::error_already_set();
87+
}
88+
if (interp == nullptr) [[unlikely]] {
89+
throw std::runtime_error("Failed to get the current Python interpreter state.");
90+
}
91+
const interpid_t interpid = PyInterpreterState_GetID(interp);
92+
if (PyErr_Occurred() != nullptr) [[unlikely]] {
93+
throw py::error_already_set();
94+
}
95+
return interpid;
96+
}
97+
98+
[[nodiscard]] inline interpid_t GetMainPyInterpreterID() {
99+
PyInterpreterState *interp = PyInterpreterState_Main();
100+
if (PyErr_Occurred() != nullptr) [[unlikely]] {
101+
throw py::error_already_set();
102+
}
103+
if (interp == nullptr) [[unlikely]] {
104+
throw std::runtime_error("Failed to get the main Python interpreter state.");
105+
}
106+
const interpid_t interpid = PyInterpreterState_GetID(interp);
107+
if (PyErr_Occurred() != nullptr) [[unlikely]] {
108+
throw py::error_already_set();
109+
}
110+
return interpid;
111+
}
112+
113+
#else
114+
115+
[[nodiscard]] inline bool IsCurrentPyInterpreterMain() noexcept { return true; }
116+
[[nodiscard]] inline interpid_t GetCurrentPyInterpreterID() noexcept { return 0; }
117+
[[nodiscard]] inline interpid_t GetMainPyInterpreterID() noexcept { return 0; }
118+
119+
#endif

include/optree/registry.h

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@ limitations under the License.
2323
#include <string> // std::string
2424
#include <unordered_map> // std::unordered_map
2525
#include <unordered_set> // std::unordered_set
26-
#include <utility> // std::pair
26+
#include <utility> // std::pair, std::make_pair
2727

2828
#include <pybind11/pybind11.h>
2929

3030
#include "optree/exceptions.h"
3131
#include "optree/hashing.h"
32+
#include "optree/pymacros.h"
3233
#include "optree/synchronization.h"
3334

3435
namespace optree {
@@ -141,6 +142,52 @@ class PyTreeTypeRegistry {
141142
return count1;
142143
}
143144

145+
// Get the number of alive interpreters that have seen the registry.
146+
[[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersAlive() {
147+
const scoped_read_lock lock{sm_mutex};
148+
return py::ssize_t_cast(sm_alive_interpids.size());
149+
}
150+
151+
// Get the number of interpreters that have seen the registry.
152+
[[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersSeen() {
153+
const scoped_read_lock lock{sm_mutex};
154+
return sm_num_interpreters_seen;
155+
}
156+
157+
// Get the IDs of alive interpreters that have seen the registry.
158+
[[nodiscard]] static inline Py_ALWAYS_INLINE std::unordered_set<interpid_t>
159+
GetAliveInterpreterIDs() {
160+
const scoped_read_lock lock{sm_mutex};
161+
return sm_alive_interpids;
162+
}
163+
164+
// Check if should preserve the insertion order of the dictionary keys during flattening.
165+
[[nodiscard]] static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered(
166+
const std::string &registry_namespace,
167+
const bool &inherit_global_namespace = true) {
168+
const scoped_read_lock lock{sm_dict_order_mutex};
169+
170+
const auto interpid = GetCurrentPyInterpreterID();
171+
const auto &namespaces = sm_dict_insertion_ordered_namespaces;
172+
return (namespaces.find({interpid, registry_namespace}) != namespaces.end()) ||
173+
(inherit_global_namespace && namespaces.find({interpid, ""}) != namespaces.end());
174+
}
175+
176+
// Set the namespace to preserve the insertion order of the dictionary keys during flattening.
177+
static inline Py_ALWAYS_INLINE void SetDictInsertionOrdered(
178+
const bool &mode,
179+
const std::string &registry_namespace) {
180+
const scoped_write_lock lock{sm_dict_order_mutex};
181+
182+
const auto interpid = GetCurrentPyInterpreterID();
183+
const auto key = std::make_pair(interpid, registry_namespace);
184+
if (mode) [[likely]] {
185+
sm_dict_insertion_ordered_namespaces.insert(key);
186+
} else [[unlikely]] {
187+
sm_dict_insertion_ordered_namespaces.erase(key);
188+
}
189+
}
190+
144191
friend void BuildModule(py::module_ &mod); // NOLINT[runtime/references]
145192

146193
private:
@@ -173,7 +220,16 @@ class PyTreeTypeRegistry {
173220
NamedRegistrationsMap m_named_registrations{};
174221
BuiltinsTypesSet m_builtins_types{};
175222

223+
// A set of namespaces that preserve the insertion order of the dictionary keys during
224+
// flattening.
225+
static inline std::unordered_set<std::pair<interpid_t, std::string>>
226+
sm_dict_insertion_ordered_namespaces{};
227+
static inline read_write_mutex sm_dict_order_mutex{};
228+
friend class PyTreeSpec;
229+
230+
static inline std::unordered_set<interpid_t> sm_alive_interpids{};
176231
static inline read_write_mutex sm_mutex{};
232+
static inline ssize_t sm_num_interpreters_seen = 0;
177233
};
178234

179235
} // namespace optree

include/optree/synchronization.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,15 @@ using scoped_recursive_lock = std::scoped_lock<recursive_mutex>;
6262
#if (defined(__APPLE__) /* header <shared_mutex> is not available on macOS build target */ && \
6363
PY_VERSION_HEX < 0x030C00F0 /* Python 3.12.0 */)
6464

65-
# undef HAVE_READ_WRITE_LOCK
65+
# undef OPTREE_HAS_READ_WRITE_LOCK
6666

6767
using read_write_mutex = mutex;
6868
using scoped_read_lock = scoped_lock;
6969
using scoped_write_lock = scoped_lock;
7070

7171
#else
7272

73-
# define HAVE_READ_WRITE_LOCK
73+
# define OPTREE_HAS_READ_WRITE_LOCK 1
7474

7575
# include <shared_mutex> // std::shared_mutex, std::shared_lock
7676

0 commit comments

Comments
 (0)