|
26 | 26 | ) |
27 | 27 | from pyomo.common.gc_manager import PauseGC |
28 | 28 | from pyomo.common.modeling import unique_component_name |
29 | | -from pyomo.common.dependencies import ( |
30 | | - dill, |
31 | | - dill_available, |
32 | | - multiprocessing, |
33 | | - attempt_import, |
34 | | -) |
| 29 | +from pyomo.common.dependencies import dill, dill_available, multiprocessing |
35 | 30 | from pyomo.common.enums import SolverAPIVersion |
36 | 31 |
|
37 | 32 | from pyomo.core import ( |
@@ -858,17 +853,16 @@ def _setup_pool(self, threads, instance, num_jobs): |
858 | 853 | "methods 'spawn' or 'forkserver', but it could " |
859 | 854 | "not be imported." |
860 | 855 | ) |
861 | | - solver_cls_path = ( |
862 | | - f"{self._config.solver.__class__.__module__}:" |
863 | | - f"{self._config.solver.__class__.__name__}" |
864 | | - ) |
865 | | - solver_opts = dict(self._config.solver.options) |
866 | | - use_pb = bool(self._config.use_primal_bound) |
867 | | - |
868 | 856 | pool = multiprocessing.get_context(method.value).Pool( |
869 | 857 | processes=threads, |
870 | 858 | initializer=_setup_spawn, |
871 | | - initargs=(dill.dumps(instance), solver_cls_path, solver_opts, use_pb), |
| 859 | + initargs=( |
| 860 | + dill.dumps(instance), |
| 861 | + f"{self._config.solver.__class__.__module__}:" |
| 862 | + f"{self._config.solver.__class__.__name__}", |
| 863 | + dill.dumps(self._config.solver.options), |
| 864 | + self._config.use_primal_bound, |
| 865 | + ), |
872 | 866 | ) |
873 | 867 | elif method == ProcessStartMethod.fork: |
874 | 868 | _thread_local.model = instance |
@@ -955,47 +949,24 @@ def get_all_M_values(self, model): |
955 | 949 |
|
956 | 950 | # Things we call in subprocesses. These can't be member functions, or |
957 | 951 | # else we'd have to pickle `self`, which is problematic. |
958 | | -def _setup_spawn(serialized_model, solver_cls_path, solver_opts, use_primal_bound): |
| 952 | +def _setup_spawn(model, solver_class_path, solver_options, use_primal_bound): |
959 | 953 | # When using 'spawn' or 'forkserver', Python starts in a new |
960 | 954 | # environment and executes only this file, so we need to manually |
961 | 955 | # ensure necessary plugins are registered (even if the main process |
962 | 956 | # has already registered them). |
963 | 957 | import pyomo.environ |
| 958 | + from importlib import import_module |
964 | 959 |
|
965 | | - _thread_local.model = dill.loads(serialized_model) |
966 | | - |
967 | | - mod_name, sep, attr = solver_cls_path.partition(":") |
968 | | - if not sep or not attr: |
969 | | - raise ValueError( |
970 | | - f"Bad entrypoint '{solver_cls_path}' (expected 'module:Class')" |
971 | | - ) |
| 960 | + global _thread_local |
972 | 961 |
|
973 | | - module, available = attempt_import(mod_name) |
974 | | - if not available: |
975 | | - raise ImportError( |
976 | | - f"Could not import module '{mod_name}' for '{solver_cls_path}'" |
977 | | - ) |
| 962 | + # Reconstruct the model and solver in the new process |
| 963 | + _thread_local.model = dill.loads(model) |
978 | 964 |
|
979 | | - SolverCls = getattr(module, attr) |
980 | | - |
981 | | - # Construct the solver and apply options (works across V1 solvers) |
982 | | - try: |
983 | | - solver = SolverCls(options=solver_opts or {}) |
984 | | - except TypeError: |
985 | | - solver = SolverCls() |
986 | | - if hasattr(solver, "set_options") and callable(solver.set_options): |
987 | | - solver.set_options(solver_opts or {}) |
988 | | - elif hasattr(solver, "options") and isinstance(solver.options, dict): |
989 | | - solver.options.update(solver_opts or {}) |
990 | | - elif solver_opts: |
991 | | - for k, v in solver_opts.items(): |
992 | | - try: |
993 | | - solver.options[k] = v |
994 | | - except Exception: |
995 | | - pass |
| 965 | + module_path, _, class_name = solver_class_path.partition(":") |
| 966 | + solver_cls = getattr(import_module(module_path), class_name) |
996 | 967 |
|
997 | | - _thread_local.solver = solver |
998 | | - _thread_local.config_use_primal_bound = bool(use_primal_bound) |
| 968 | + _thread_local.solver = solver_cls(options=solver_options) |
| 969 | + _thread_local.config_use_primal_bound = use_primal_bound |
999 | 970 |
|
1000 | 971 |
|
1001 | 972 | def _setup_fork(): |
|
0 commit comments