@@ -23,12 +23,13 @@ limitations under the License.
2323#include < string> // std::string
2424#include < unordered_map> // std::unordered_map
2525#include < unordered_set> // std::unordered_set
26- #include < utility> // std::pair
26+ #include < utility> // std::pair, std::make_pair
2727
2828#include < pybind11/pybind11.h>
2929
3030#include " optree/exceptions.h"
3131#include " optree/hashing.h"
32+ #include " optree/pymacros.h"
3233#include " optree/synchronization.h"
3334
3435namespace optree {
@@ -141,6 +142,52 @@ class PyTreeTypeRegistry {
141142 return count1;
142143 }
143144
145+ // Get the number of alive interpreters that have seen the registry.
146+ [[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersAlive () {
147+ const scoped_read_lock lock{sm_mutex};
148+ return py::ssize_t_cast (sm_alive_interpids.size ());
149+ }
150+
151+ // Get the number of interpreters that have seen the registry.
152+ [[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersSeen () {
153+ const scoped_read_lock lock{sm_mutex};
154+ return sm_num_interpreters_seen;
155+ }
156+
157+ // Get the IDs of alive interpreters that have seen the registry.
158+ [[nodiscard]] static inline Py_ALWAYS_INLINE std::unordered_set<interpid_t >
159+ GetAliveInterpreterIDs () {
160+ const scoped_read_lock lock{sm_mutex};
161+ return sm_alive_interpids;
162+ }
163+
164+ // Check if should preserve the insertion order of the dictionary keys during flattening.
165+ [[nodiscard]] static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered (
166+ const std::string ®istry_namespace,
167+ const bool &inherit_global_namespace = true ) {
168+ const scoped_read_lock lock{sm_dict_order_mutex};
169+
170+ const auto interpid = GetCurrentPyInterpreterID ();
171+ const auto &namespaces = sm_dict_insertion_ordered_namespaces;
172+ return (namespaces.find ({interpid, registry_namespace}) != namespaces.end ()) ||
173+ (inherit_global_namespace && namespaces.find ({interpid, " " }) != namespaces.end ());
174+ }
175+
176+ // Set the namespace to preserve the insertion order of the dictionary keys during flattening.
177+ static inline Py_ALWAYS_INLINE void SetDictInsertionOrdered (
178+ const bool &mode,
179+ const std::string ®istry_namespace) {
180+ const scoped_write_lock lock{sm_dict_order_mutex};
181+
182+ const auto interpid = GetCurrentPyInterpreterID ();
183+ const auto key = std::make_pair (interpid, registry_namespace);
184+ if (mode) [[likely]] {
185+ sm_dict_insertion_ordered_namespaces.insert (key);
186+ } else [[unlikely]] {
187+ sm_dict_insertion_ordered_namespaces.erase (key);
188+ }
189+ }
190+
144191 friend void BuildModule (py::module_ &mod); // NOLINT[runtime/references]
145192
146193private:
@@ -173,7 +220,16 @@ class PyTreeTypeRegistry {
173220 NamedRegistrationsMap m_named_registrations{};
174221 BuiltinsTypesSet m_builtins_types{};
175222
223+ // A set of namespaces that preserve the insertion order of the dictionary keys during
224+ // flattening.
225+ static inline std::unordered_set<std::pair<interpid_t , std::string>>
226+ sm_dict_insertion_ordered_namespaces{};
227+ static inline read_write_mutex sm_dict_order_mutex{};
228+ friend class PyTreeSpec ;
229+
230+ static inline std::unordered_set<interpid_t > sm_alive_interpids{};
176231 static inline read_write_mutex sm_mutex{};
232+ static inline ssize_t sm_num_interpreters_seen = 0 ;
177233};
178234
179235} // namespace optree
0 commit comments