|
| 1 | +import importlib |
| 2 | +import inspect |
| 3 | +import sys |
| 4 | +from pathlib import Path |
| 5 | +from typing import get_type_hints |
| 6 | + |
| 7 | +from fluxqueue import Context |
| 8 | + |
| 9 | + |
| 10 | +def get_registry(module_path: str, queue: str, module_dir: str | None = None): |
| 11 | + if module_dir: |
| 12 | + module_dir_path = Path(module_dir).resolve() |
| 13 | + if str(module_dir_path) not in sys.path: |
| 14 | + sys.path.insert(0, str(module_dir_path)) |
| 15 | + |
| 16 | + module = importlib.import_module(module_path) |
| 17 | + registry = {"tasks": {}, "contexts": {}} |
| 18 | + for _name, obj in inspect.getmembers(module): |
| 19 | + if inspect.isfunction(obj): |
| 20 | + task_name = getattr(obj, "task_name", None) |
| 21 | + task_queue = getattr(obj, "queue", None) |
| 22 | + if not task_queue or task_queue != queue: |
| 23 | + continue |
| 24 | + |
| 25 | + if registry["tasks"].get(task_name): |
| 26 | + raise ValueError(f"Task '{task_name}' is duplicated") |
| 27 | + |
| 28 | + original_func = getattr(obj, "__wrapped__", obj) |
| 29 | + |
| 30 | + hints = get_type_hints(original_func) |
| 31 | + sig = inspect.signature(original_func) |
| 32 | + context_params = { |
| 33 | + name: hints[name] |
| 34 | + for name in sig.parameters |
| 35 | + if name in hints |
| 36 | + and isinstance(hints[name], type) |
| 37 | + and issubclass(hints[name], Context) |
| 38 | + } |
| 39 | + if not context_params: |
| 40 | + context_name = None |
| 41 | + else: |
| 42 | + context = context_params[next(iter(context_params))] |
| 43 | + context_name = getattr(context, "__fluxqueue_context__", None) |
| 44 | + |
| 45 | + registry["tasks"][task_name] = { |
| 46 | + "func": original_func, |
| 47 | + "context_name": context_name, |
| 48 | + } |
| 49 | + elif inspect.isclass(obj): |
| 50 | + if not issubclass(obj, Context): |
| 51 | + continue |
| 52 | + |
| 53 | + context_name = getattr(obj, "__fluxqueue_context__", None) |
| 54 | + if registry["contexts"].get(context_name): |
| 55 | + raise ValueError(f"Context '{context_name}' is duplicated") |
| 56 | + |
| 57 | + registry["contexts"][context_name] = obj |
| 58 | + |
| 59 | + return registry |
0 commit comments