diff --git a/metaflow/cmd/debug_cli.py b/metaflow/cmd/debug_cli.py
new file mode 100644
index 00000000000..1541a71873d
--- /dev/null
+++ b/metaflow/cmd/debug_cli.py
@@ -0,0 +1,160 @@
+import json
+import os
+
+from metaflow._vendor import click
+
+METAFLOW_ATTACH_CONFIG = {
+    "name": "Metaflow: Attach",
+    "type": "debugpy",
+    "request": "attach",
+    "connect": {"host": "localhost", "port": 5678},
+    "justMyCode": True,
+    "autoAttachChildProcesses": True,
+}
+
+
+@click.group()
+def cli():
+    pass
+
+
+@cli.group(help="Commands related to debugging Metaflow flows.")
+def debug():
+    pass
+
+
+@debug.group(help="VSCode debugger integration.")
+def vscode():
+    pass
+
+
+def _write_launch_json(launch_path, content):
+    """Write *content* to *launch_path* as indented JSON with a trailing newline."""
+    with open(launch_path, "w") as f:
+        json.dump(content, f, indent=4)
+        f.write("\n")
+
+
+def _warn_unusable_launch_json(reason):
+    """Tell the user why the existing launch.json was left untouched."""
+    click.echo(
+        "Warning: existing launch.json %s. Use --overwrite to replace it." % reason,
+        err=True,
+    )
+
+
+@vscode.command(
+    "install-config",
+    help="Install VSCode debug configuration for attaching to Metaflow tasks.",
+)
+@click.option(
+    "--base-port",
+    default=5678,
+    type=int,
+    show_default=True,
+    help="Port number for the debugpy attach configuration.",
+)
+@click.option(
+    "--dir",
+    "target_dir",
+    default=".",
+    type=click.Path(),
+    help="Workspace root directory where .vscode/ will be created.",
+)
+@click.option(
+    "--overwrite",
+    is_flag=True,
+    default=False,
+    help="Overwrite existing launch.json instead of merging.",
+)
+@click.option(
+    "--remote-root",
+    default=None,
+    type=str,
+    help="Remote container root (e.g. /root/metaflow). Adds pathMappings for remote debugging.",
+)
+def install_config(base_port, target_dir, overwrite, remote_root):
+    """Create or update <target_dir>/.vscode/launch.json with the attach config.
+
+    A missing file, or any file when --overwrite is given, is written from
+    scratch. Otherwise the existing file is parsed and the entry named
+    "Metaflow: Attach" is replaced in place, or appended if absent; every
+    other key and configuration is written back as parsed. A file that cannot
+    be merged into (invalid JSON or an unexpected structure) is left untouched
+    and a warning is printed to stderr.
+    """
+    target_dir = os.path.abspath(target_dir)
+    vscode_dir = os.path.join(target_dir, ".vscode")
+    launch_path = os.path.join(vscode_dir, "launch.json")
+
+    # Shallow copy is enough: the only nested value ("connect") is replaced,
+    # never mutated, so the module-level template stays pristine.
+    our_config = dict(METAFLOW_ATTACH_CONFIG)
+    our_config["connect"] = {"host": "localhost", "port": base_port}
+
+    if remote_root is not None:
+        import metaflow
+
+        # Parent dir containing the metaflow package (e.g. site-packages or repo root).
+        local_metaflow_src = os.path.dirname(os.path.dirname(metaflow.__file__))
+        our_config["pathMappings"] = [
+            {
+                "localRoot": "${workspaceFolder}",
+                "remoteRoot": remote_root,
+            },
+            {
+                "localRoot": local_metaflow_src,
+                "remoteRoot": remote_root + "/.mf_code",
+            },
+        ]
+
+    # exist_ok closes the window between an isdir() check and the creation.
+    os.makedirs(vscode_dir, exist_ok=True)
+
+    if overwrite or not os.path.exists(launch_path):
+        _write_launch_json(
+            launch_path, {"version": "0.2.0", "configurations": [our_config]}
+        )
+        click.echo(
+            "Created %s with 'Metaflow: Attach' config (port %d)"
+            % (launch_path, base_port)
+        )
+        return
+
+    with open(launch_path, "r") as f:
+        try:
+            existing = json.load(f)
+        except json.JSONDecodeError:
+            _warn_unusable_launch_json("is not valid JSON")
+            return
+
+    # Valid JSON is not necessarily a launch file: the top level may be a list
+    # or scalar, and "configurations" may be null or a non-list. Refuse to
+    # merge instead of failing with an AttributeError further down.
+    if not isinstance(existing, dict):
+        _warn_unusable_launch_json("is not a JSON object")
+        return
+    configs = existing.setdefault("configurations", [])
+    if not isinstance(configs, list):
+        _warn_unusable_launch_json("has a 'configurations' entry that is not a list")
+        return
+
+    config_name = our_config["name"]
+    for i, cfg in enumerate(configs):
+        # Skip malformed (non-object) entries rather than crash on them.
+        if not isinstance(cfg, dict) or cfg.get("name") != config_name:
+            continue
+        if cfg == our_config:
+            click.echo(
+                "'Metaflow: Attach' config already up to date in %s" % launch_path
+            )
+        else:
+            configs[i] = our_config
+            _write_launch_json(launch_path, existing)
+            click.echo("Updated 'Metaflow: Attach' config in %s" % launch_path)
+        return
+
+    # No entry with our name yet: append one.
+    configs.append(our_config)
+    _write_launch_json(launch_path, existing)
+    click.echo("Added 'Metaflow: Attach' config to existing %s" % launch_path)
diff --git a/metaflow/cmd/main_cli.py b/metaflow/cmd/main_cli.py
index c552b32a24b..7c2c1d76555 100644
--- a/metaflow/cmd/main_cli.py
+++ b/metaflow/cmd/main_cli.py
@@ -68,6 +68,7 @@ def status():
     ("tutorials", ".tutorials_cmd.cli"),
     ("develop", ".develop.cli"),
     ("code", ".code.cli"),
+    ("debug", ".debug_cli.cli"),
 ]
 
 process_cmds(globals())
diff --git a/metaflow/extension_support/cmd.py b/metaflow/extension_support/cmd.py
index 1d9e457a152..1cc704706e5 100644
--- a/metaflow/extension_support/cmd.py
+++ b/metaflow/extension_support/cmd.py
@@ -49,7 +49,7 @@ def process_cmds(module_globals):
     # override metaflow core)
     for name, class_path in _all_cmds:
         _ext_debug("    Adding command '%s' from '%s'" % (name, class_path))
-        _all_cmds_dict[name] = class_path
+        _all_cmds_dict.setdefault(name, []).append(class_path)
 
     # Resolve the ENABLED_CMD variable. The rules are the following:
     # - if ENABLED_CMD is non None, it means it was either set directly by the user
@@ -83,33 +83,71 @@ def resolve_cmds():
     to_return = []
 
     for name in set_of_commands:
-        class_path = _all_cmds_dict.get(name, None)
-        if class_path is None:
+        class_paths = _all_cmds_dict.get(name, None)
+        if class_paths is None:
             raise ValueError(
                 "Configuration requested command '%s' but no such command is available"
                 % name
             )
-        path, cls_name = class_path.rsplit(".", 1)
-        try:
-            cmd_module = importlib.import_module(path)
-        except ImportError:
-            raise ValueError("Cannot locate command '%s' at '%s'" % (name, path))
-
-        cls = getattr(cmd_module, cls_name, None)
-        if cls is None:
-            raise ValueError(
-                "Cannot locate '%s' class for command at '%s'" % (cls_name, path)
-            )
-        all_cmds = list(cls.commands)
-        if len(all_cmds) > 1:
-            raise ValueError("%s defines more than one command -- use a group" % path)
-        if all_cmds[0] != name:
-            raise ValueError(
-                "%s: expected name to be '%s' but got '%s' instead"
-                % (path, name, all_cmds[0])
+
+        def _load_cmd_cls(class_path, name):
+            path, cls_name = class_path.rsplit(".", 1)
+            try:
+                cmd_module = importlib.import_module(path)
+            except ImportError as e:
+                # Chain the cause so the real import failure stays visible.
+                raise ValueError(
+                    "Cannot locate command '%s' at '%s'" % (name, path)
+                ) from e
+            cls = getattr(cmd_module, cls_name, None)
+            if cls is None:
+                raise ValueError(
+                    "Cannot locate '%s' class for command at '%s'" % (cls_name, path)
+                )
+            all_cmds = list(cls.commands)
+            if not all_cmds:
+                # Guard the all_cmds[0] lookup below against a provider with no commands.
+                raise ValueError("%s does not define any command" % path)
+            if len(all_cmds) > 1:
+                raise ValueError(
+                    "%s defines more than one command -- use a group" % path
+                )
+            if all_cmds[0] != name:
+                raise ValueError(
+                    "%s: expected name to be '%s' but got '%s' instead"
+                    % (path, name, all_cmds[0])
+                )
+            return cls
+
+        if len(class_paths) == 1:
+            cls = _load_cmd_cls(class_paths[0], name)
+            to_return.append(cls)
+            _ext_debug("    Added command '%s' from '%s'" % (name, class_paths[0]))
+        else:
+            # Multiple providers for the same command name — merge subcommands.
+            # The last entry (extension) is the base; earlier entries contribute
+            # subcommands that don't collide with the base.
+            # This is effectively overriding anything in the previous extensions
+            # with later extensions.
+            base_cls = _load_cmd_cls(class_paths[-1], name)
+            base_group = base_cls.commands[name]
+
+            for earlier_path in class_paths[:-1]:
+                earlier_cls = _load_cmd_cls(earlier_path, name)
+                earlier_group = earlier_cls.commands[name]
+                for cmd_name, cmd in earlier_group.commands.items():
+                    if cmd_name not in base_group.commands:
+                        base_group.add_command(cmd, cmd_name)
+                        _ext_debug(
+                            "    Merged subcommand '%s' into '%s' from '%s'"
+                            % (cmd_name, name, earlier_path)
+                        )
+
+            to_return.append(base_cls)
+            _ext_debug(
+                "    Added merged command '%s' (base from '%s', %d providers)"
+                % (name, class_paths[-1], len(class_paths))
             )
-        to_return.append(cls)
-        _ext_debug("    Added command '%s' from '%s'" % (name, class_path))
 
     return to_return
 
diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py
index bf9fb7bf90b..cfd31a892a5 100644
--- a/metaflow/plugins/__init__.py
+++ b/metaflow/plugins/__init__.py
@@ -55,6 +55,7 @@
     ("airflow_internal", ".airflow.airflow_decorator.AirflowInternalDecorator"),
     ("pypi", ".pypi.pypi_decorator.PyPIStepDecorator"),
     ("conda", ".pypi.conda_decorator.CondaStepDecorator"),
+    ("debugger", ".debugger_step_decorator.DebuggerStepDecorator"),
 ]
 
 # Add new flow decorators here
diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py
index 1825d7c4cae..17694f5acf2 100644
--- a/metaflow/plugins/aws/batch/batch_cli.py
+++ b/metaflow/plugins/aws/batch/batch_cli.py
@@ -315,6 +315,11 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
     if split_vars:
         env.update(split_vars)
 
+    # Forward debugger env vars to the remote container.
+    for key, val in os.environ.items():
+        if key.startswith("METAFLOW_DEBUGPY_"):
+            env[key] = val
+
     if retry_count:
         ctx.obj.echo_always(
             "Sleeping %d minutes before the next AWS Batch retry"
diff --git a/metaflow/plugins/debugger_step_decorator.py b/metaflow/plugins/debugger_step_decorator.py
new file mode 100644
index 00000000000..a21050b2d28
--- /dev/null
+++ b/metaflow/plugins/debugger_step_decorator.py
@@ -0,0 +1,646 @@
+"""
+debugger_step_decorator.py — VSCode/debugpy integration for Metaflow tasks.
+
+Enables interactive debugging of every task subprocess. Two modes depending on
+whether the flow runs locally or on remote compute (Kubernetes/Batch):
+
+LOCAL MODE (mode="connect")
+===========================
+
+  Developer machine
+  ┌──────────────────────────────────────────────┐
+  │                                              │
+  │   VSCode ◄──DAP──► debugpy adapter (:5678)   │
+  │                      ▲             ▲         │
+  │                      │             │         │
+  │                   (pydevd       (pydevd      │
+  │                  protocol)     protocol)     │
+  │                      │             │         │
+  │                 task (pid1)   task (pid2)    │
+  │                                              │
+  └──────────────────────────────────────────────┘
+
+  1. runtime_init: starts the debugpy adapter via debugpy.listen().
+  2. runtime_step_cli: passes adapter host/port to each task via env vars.
+  3. task_pre_step (_task_connect): each task calls debugpy.connect() back to
+     the adapter, which notifies VSCode to auto-attach a new debug session.
+
+
+REMOTE MODE (mode="listen")
+===========================
+
+  Developer machine                          Remote container
+  ┌─────────────────────────────┐            ┌─────────────────┐
+  │                             │            │                 │
+  │  VSCode ◄─DAP─► adapter     │            │   pydevd (:5678)│
+  │                 ▲           │            │     ▲           │
+  │              (pydevd        │            │     │           │
+  │             protocol)       │            │ task code       │
+  │                 │           │            │     │           │
+  │              bridge ◄───────╂── TCP ─────╂─────┘           │
+  │                 ▲           │            │                 │
+  │                 │           │            │                 │
+  │   callback server           │◄── TCP ────╂── callback      │
+  │   (ephemeral port)          │ (endpoint  │   (host:port)   │
+  │                             │  JSON)     │                 │
+  └─────────────────────────────┘            └─────────────────┘
+
+  1. runtime_init: starts debugpy.listen() + a callback server on an
+     ephemeral port. Passes callback host/port to tasks via env vars.
+
+  2. task_pre_step (_task_listen): the remote task:
+     a. Pre-binds a server socket on base_port.
+     b. Sends its {host, port} as JSON to the callback server.
+     c. Starts pydevd.settrace() in server mode (reusing the pre-bound socket).
+
+  3. _handle_callback (on the developer machine, per task):
+     a. Reads the task's {host, port} from the callback connection.
+     b. Opens a "bridge" socket to the adapter's internal pydevd server.
+     c. Completes the 2-message DAP handshake that the adapter expects
+        (pydevdAuthorize + pydevdSystemInfo) so the adapter registers
+        a new debug session.
+     d. Connects to the remote task's pydevd and pipes traffic in both
+        directions: adapter <--bridge--> remote pydevd.
+
+  This makes each remote task appear as a local child process to the adapter,
+  so VSCode auto-attaches seamlessly.
+
+
+LIFECYCLE (StepDecorator hooks)
+===============================
+
+  step_init         → verify debugpy is importable
+  runtime_init      → start adapter + callback server, print banner
+  runtime_step_cli  → inject env vars for the task subprocess
+  task_pre_step     → connect or listen depending on mode
+  runtime_finished  → reset class-level state
+"""
+
+import json
+import os
+import socket
+import sys
+import threading
+
+from metaflow.decorators import StepDecorator
+from metaflow.exception import MetaflowException
+
+_ENV_ADAPTER_HOST = "METAFLOW_DEBUGPY_ADAPTER_HOST"
+_ENV_ADAPTER_PORT = "METAFLOW_DEBUGPY_ADAPTER_PORT"
+_ENV_PARENT_PID = "METAFLOW_DEBUGPY_PARENT_PID"
+_ENV_WAIT_FOR_CLIENT = "METAFLOW_DEBUGPY_WAIT_FOR_CLIENT"
+_ENV_DEBUG_MODE = "METAFLOW_DEBUGPY_MODE"
+_ENV_BASE_PORT = "METAFLOW_DEBUGPY_BASE_PORT"
+_ENV_CALLBACK_HOST = "METAFLOW_DEBUGPY_CALLBACK_HOST"
+_ENV_CALLBACK_PORT = "METAFLOW_DEBUGPY_CALLBACK_PORT"
+_ENV_ACCESS_TOKEN = "METAFLOW_DEBUGPY_ACCESS_TOKEN"
+
+_LOG_PREFIX = "[DEBUGGER]"
+
+
+def _log(msg, *args):
+    """Write *msg* (%-formatted with *args*, if any) to stderr with the log prefix."""
+    if args:
+        msg = msg % args
+    sys.stderr.write("%s %s\n" % (_LOG_PREFIX, msg))
+    sys.stderr.flush()
+
+
+def _read_dap_message(sock):
+    """Read one DAP-framed JSON message from *sock*."""
+    buf = b""
+    # Read the header one byte at a time so nothing past the blank-line
+    # terminator (i.e. the start of the body) is consumed here.
+    while b"\r\n\r\n" not in buf:
+        chunk = sock.recv(1)
+        if not chunk:
+            raise ConnectionError("socket closed while reading DAP header")
+        buf += chunk
+    header, _ = buf.split(b"\r\n\r\n", 1)
+    content_length = None
+    for line in header.split(b"\r\n"):
+        if line.lower().startswith(b"content-length:"):
+            content_length = int(line.split(b":", 1)[1].strip())
+            break
+    if content_length is None:
+        raise ValueError("DAP message missing Content-Length header")
+    body = b""
+    while len(body) < content_length:
+        chunk = sock.recv(content_length - len(body))
+        if not chunk:
+            raise ConnectionError("socket closed while reading DAP body")
+        body += chunk
+    return json.loads(body)
+
+
+def _write_dap_message(sock, body):
+    """Write a DAP-framed JSON message to *sock*."""
+    payload = json.dumps(body, separators=(",", ":")).encode("utf-8")
+    header = ("Content-Length: %d\r\n\r\n" % len(payload)).encode("utf-8")
+    sock.sendall(header + payload)
+
+
+def _pipe(src, dst):
+    """Copy bytes from *src* to *dst* until EOF or error."""
+    try:
+        while True:
+            data = src.recv(65536)
+            if not data:
+                break
+            dst.sendall(data)
+    except OSError:
+        pass
+    finally:
+        try:
+            dst.shutdown(socket.SHUT_WR)
+        except OSError:
+            pass
+
+
+_fake_pid_counter = 100000
+_fake_pid_lock = threading.Lock()
+
+
+def _next_fake_pid():
+    """Return the next value of a process-wide, lock-protected counter."""
+    global _fake_pid_counter
+    with _fake_pid_lock:
+        _fake_pid_counter += 1
+        return _fake_pid_counter
+
+
+def _handle_callback(conn, adapter_info):
+    """Handle one callback from a remote task.
+
+    Reads the task's endpoint, connects a bridge to the local adapter's
+    internal pydevd server, completes the 2-message DAP handshake
+    (pydevdAuthorize + pydevdSystemInfo), then forwards all traffic
+    between the adapter and the remote task's pydevd.
+
+    Failures are logged rather than raised. Any socket opened before the
+    failure is closed so that a broken callback does not leak descriptors.
+    """
+    bridge = None
+    remote = None
+    try:
+        # Read endpoint JSON from the task until it closes its sending side.
+        data = b""
+        try:
+            while True:
+                chunk = conn.recv(4096)
+                if not chunk:
+                    break
+                data += chunk
+        finally:
+            # Close the callback connection even if recv() fails.
+            conn.close()
+        endpoint = json.loads(data)
+        task_host = endpoint["host"]
+        task_port = int(endpoint["port"])
+
+        # Connect bridge to adapter's internal pydevd server.
+        bridge = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        bridge.connect((adapter_info["host"], adapter_info["port"]))
+
+        # DAP handshake — respond to pydevdAuthorize + pydevdSystemInfo.
+        seq = 0
+
+        def _next_seq():
+            nonlocal seq
+            seq += 1
+            return seq
+
+        msg = _read_dap_message(bridge)
+        if msg.get("command") != "pydevdAuthorize":
+            raise RuntimeError("expected pydevdAuthorize, got %s" % msg.get("command"))
+        # adapter.access_token is always None when spawned via debugpy.listen().
+        _write_dap_message(
+            bridge,
+            {
+                "seq": _next_seq(),
+                "type": "response",
+                "request_seq": msg["seq"],
+                "success": True,
+                "command": "pydevdAuthorize",
+                "body": {"clientAccessToken": None},
+            },
+        )
+
+        msg = _read_dap_message(bridge)
+        if msg.get("command") != "pydevdSystemInfo":
+            raise RuntimeError("expected pydevdSystemInfo, got %s" % msg.get("command"))
+        _write_dap_message(
+            bridge,
+            {
+                "seq": _next_seq(),
+                "type": "response",
+                "request_seq": msg["seq"],
+                "success": True,
+                "command": "pydevdSystemInfo",
+                "body": {
+                    "python": {
+                        "version": "3.11.0",
+                        "implementation": {"name": "cpython", "version": "3.11.0"},
+                    },
+                    "platform": {"name": "linux"},
+                    "process": {
+                        "pid": _next_fake_pid(),
+                        "ppid": adapter_info["parent_pid"],
+                        "executable": "python",
+                        "bitness": 64,
+                    },
+                },
+            },
+        )
+
+        # Forward all subsequent traffic: adapter <-> remote task's pydevd.
+        remote = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        remote.connect((task_host, task_port))
+        threading.Thread(target=_pipe, args=(bridge, remote), daemon=True).start()
+        threading.Thread(target=_pipe, args=(remote, bridge), daemon=True).start()
+
+    except Exception as exc:
+        _log("Bridge setup failed: %s", exc)
+        # Release whatever was opened before the failure.
+        for sock in (bridge, remote):
+            if sock is not None:
+                try:
+                    sock.close()
+                except OSError:
+                    pass
+
+
+def _start_callback_server(adapter_info):
+    """Start a TCP server that accepts callback connections from remote tasks.
+
+    Returns the port of the listening server.
+    """
+    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    server.bind(("0.0.0.0", 0))
+    server.listen(16)
+    _, port = server.getsockname()
+
+    def _accept_loop():
+        while True:
+            try:
+                conn, _ = server.accept()
+            except OSError:
+                break
+            threading.Thread(
+                target=_handle_callback,
+                args=(conn, adapter_info),
+                daemon=True,
+            ).start()
+
+    threading.Thread(target=_accept_loop, daemon=True).start()
+    return port
+
+
+def _create_listen_socket(host, port):
+    """Create a TCP server socket bound to *host*:*port*."""
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    sock.bind((host, port))
+    sock.listen(1)
+    return sock
+
+
+class DebuggerStepDecorator(StepDecorator):
+    """
+    Step decorator that enables interactive debugging of Metaflow tasks via debugpy.
+
+    **Local tasks** connect back to the runtime's debugpy adapter, triggering
+    VSCode to auto-attach a new debug session for each task.
+
+    **Remote tasks** (Kubernetes/Batch) start their own pydevd listener
+    and call back to the runtime, which bridges the connection through the
+    local adapter so VSCode auto-attaches seamlessly.
+
+    Usage::
+
+        python flow.py run --with debugger
+        python flow.py run --with debugger:base_port=9000
+        python flow.py run --with debugger:wait_for_client=False
+
+    Or as a step annotation::
+
+        @debugger
+        @step
+        def my_step(self):
+            ...
+    """
+
+    name = "debugger"
+    defaults = {
+        "base_port": "5678",
+        "wait_for_client": "True",
+        "enabled": "True",
+        "host": "auto",
+    }
+
+    _REMOTE_COMPUTE_DECORATORS = {"titus", "kubernetes", "batch"}
+
+    # Run-wide state shared by every instance; reset in runtime_finished().
+    _adapter_info = None
+    _banner_printed = False
+    _is_remote = False
+    _callback_host = None
+    _callback_port = None
+
+    @property
+    def _is_enabled(self):
+        return self.attributes["enabled"].lower() == "true"
+
+    @property
+    def _wait_for_client(self):
+        return self.attributes["wait_for_client"].lower() == "true"
+
+    def _has_remote_compute(self, graph):
+        """Return True if any step in *graph* carries a remote-compute decorator."""
+        return any(
+            deco.name in self._REMOTE_COMPUTE_DECORATORS
+            for node in graph
+            for deco in node.decorators
+        )
+
+    @staticmethod
+    def _get_routable_ip():
+        """Return the local address the OS selects for reaching 8.8.8.8."""
+        # connect() on a datagram socket sends no packet; it only picks a route.
+        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        try:
+            s.connect(("8.8.8.8", 80))
+            return s.getsockname()[0]
+        finally:
+            s.close()
+
+    def step_init(
+        self, flow, graph, step_name, decorators, environment, flow_datastore, logger
+    ):
+        """Fail early if debugpy cannot be imported (only when enabled)."""
+        if not self._is_enabled:
+            return
+        try:
+            import debugpy  # noqa: F401
+        except ImportError:
+            raise MetaflowException(
+                "The @debugger decorator requires the 'debugpy' package. "
+                "Install it with: pip install debugpy or, add it using @pypi or @conda"
+            )
+
+    def runtime_init(self, flow, graph, package, run_id):
+        """Start the debugpy adapter once per run and record how tasks reach it.
+
+        The first enabled decorator instance does the work; later instances see
+        the class-level _adapter_info already set and return early.
+        """
+        if not self._is_enabled:
+            return
+
+        if self.__class__._adapter_info is not None:
+            self._print_banner()
+            return
+
+        import debugpy
+
+        base_port = int(self.attributes["base_port"])
+        host_attr = self.attributes["host"]
+
+        is_remote = self._has_remote_compute(graph)
+        self.__class__._is_remote = is_remote
+
+        if host_attr == "auto":
+            listen_host = "0.0.0.0" if is_remote else "127.0.0.1"
+        else:
+            listen_host = host_attr
+
+        debugpy.listen((listen_host, base_port))
+
+        if host_attr != "auto":
+            connect_host = host_attr
+        elif is_remote:
+            connect_host = self._get_routable_ip()
+        else:
+            connect_host = "127.0.0.1"
+
+        # Read adapter internal server info (available after debugpy.listen).
+        from pydevd import SetupHolder
+
+        setup = SetupHolder.setup
+        if setup is None:
+            raise MetaflowException(
+                "Failed to initialize debugpy: could not read adapter info"
+            )
+
+        adapter_info = {
+            "mode": "listen" if is_remote else "connect",
+            "host": setup["client"],
+            "port": int(setup["port"]),
+            "parent_pid": os.getpid(),
+            "connect_host": connect_host,
+        }
+
+        if is_remote:
+            adapter_info["access_token"] = setup.get("access-token")
+            cb_port = _start_callback_server(adapter_info)
+            self.__class__._callback_host = connect_host
+            self.__class__._callback_port = cb_port
+
+        self.__class__._adapter_info = adapter_info
+        self._print_banner()
+
+        if self._wait_for_client:
+            _log("Waiting for VSCode to attach on port %d ...", base_port)
+            debugpy.wait_for_client()
+
+    def _print_banner(self):
+        """Print the one-time stderr banner describing how to attach."""
+        if self.__class__._banner_printed:
+            return
+        self.__class__._banner_printed = True
+
+        is_remote = self.__class__._is_remote
+        base_port = int(self.attributes["base_port"])
+        lines = [
+            "",
+            "=" * 60,
+            "%s debugpy is enabled for this run." % _LOG_PREFIX,
+        ]
+        if is_remote:
+            lines.extend(
+                [
+                    "%s Remote compute detected -- seamless bridge mode." % _LOG_PREFIX,
+                    "%s Ensure debugpy is available in the container"
+                    " (e.g. @pypi(packages={'debugpy': '...'}))." % _LOG_PREFIX,
+                ]
+            )
+        else:
+            lines.extend(
+                [
+                    "%s Attach VSCode to localhost:%d to start debugging."
+                    % (_LOG_PREFIX, base_port),
+                ]
+            )
+        if self._wait_for_client:
+            lines.append(
+                "%s Tasks will WAIT for a debugger client to attach." % _LOG_PREFIX
+            )
+        lines.append("=" * 60)
+        sys.stderr.write("\n".join(lines) + "\n")
+
+    def runtime_step_cli(
+        self, cli_args, retry_count, max_user_code_retries, ubf_context
+    ):
+        """Pass the debug mode and connection details to the task via env vars."""
+        if not self._is_enabled:
+            return
+
+        info = self.__class__._adapter_info
+        if info is None:
+            return
+
+        mode = info["mode"]
+        cli_args.env[_ENV_DEBUG_MODE] = mode
+
+        if mode == "connect":
+            cli_args.env[_ENV_ADAPTER_HOST] = str(info["connect_host"])
+            cli_args.env[_ENV_ADAPTER_PORT] = str(info["port"])
+            cli_args.env[_ENV_PARENT_PID] = str(info["parent_pid"])
+        elif mode == "listen":
+            cli_args.env[_ENV_BASE_PORT] = self.attributes["base_port"]
+            cli_args.env[_ENV_CALLBACK_HOST] = str(self.__class__._callback_host)
+            cli_args.env[_ENV_CALLBACK_PORT] = str(self.__class__._callback_port)
+            if info.get("access_token"):
+                cli_args.env[_ENV_ACCESS_TOKEN] = info["access_token"]
+
+        if self._wait_for_client:
+            cli_args.env[_ENV_WAIT_FOR_CLIENT] = "1"
+
+    def task_pre_step(
+        self,
+        step_name,
+        task_datastore,
+        metadata,
+        run_id,
+        task_id,
+        flow,
+        graph,
+        retry_count,
+        max_user_code_retries,
+        ubf_context,
+        inputs,
+    ):
+        """Attach the task to the debugger according to the mode set by the runtime.
+
+        Setup errors are logged with a traceback and swallowed, so the step
+        continues without a debugger.
+        """
+        if not self._is_enabled:
+            return
+
+        mode = os.environ.get(_ENV_DEBUG_MODE, "")
+        if not mode:
+            return
+
+        wait = os.environ.get(_ENV_WAIT_FOR_CLIENT) == "1"
+        task_key = "%s/%s/%s" % (flow.name, step_name, task_id)
+
+        try:
+            if mode == "listen":
+                self._task_listen(task_key, wait)
+            elif mode == "connect":
+                self._task_connect(task_key, wait)
+        except Exception as e:
+            _log("%s debugger setup failed: %s", task_key, e)
+            import traceback
+
+            traceback.print_exc(file=sys.stderr)
+
+    def _task_listen(self, task_key, wait):
+        """Remote mode: start raw pydevd in server mode.
+
+        We use pydevd directly (not debugpy.listen) because the local adapter
+        connects through a bridge speaking the internal pydevd protocol.
+
+        The pre-bound listening socket and the callback socket are closed on
+        every exit path, so a failed setup does not leave base_port bound for
+        the rest of the task.
+        """
+        base_port = int(os.environ[_ENV_BASE_PORT])
+
+        import debugpy._vendored.force_pydevd  # noqa: F401
+        import pydevd
+
+        # Pre-create the listening socket before sending the callback to
+        # eliminate the race where the runtime tries to connect before
+        # pydevd.settrace() has bound the port.
+        server_sock = _create_listen_socket("", base_port)
+        try:
+            # Report our endpoint back to the runtime.
+            callback_host = os.environ[_ENV_CALLBACK_HOST]
+            callback_port = int(os.environ[_ENV_CALLBACK_PORT])
+            container_ip = socket.gethostbyname(socket.gethostname())
+            payload = json.dumps({"host": container_ip, "port": base_port})
+            # The context manager closes the socket even if connect()/sendall()
+            # raises or times out.
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as cb:
+                cb.settimeout(30)
+                cb.connect((callback_host, callback_port))
+                cb.sendall(payload.encode("utf-8"))
+                cb.shutdown(socket.SHUT_WR)
+
+            # Monkey-patch pydevd.start_server to reuse our pre-bound socket.
+            # pydevd.py imports start_server at module level, so we must patch
+            # the reference in the pydevd module, not in pydevd_comm.
+            _orig = pydevd.start_server
+
+            def _patched(port):
+                conn, _ = server_sock.accept()
+                server_sock.close()
+                return conn
+
+            pydevd.start_server = _patched
+            try:
+                kwargs = dict(
+                    host="",
+                    port=base_port,
+                    suspend=False,
+                    block_until_connected=True,
+                    wait_for_ready_to_run=wait,
+                    protocol="dap",
+                )
+                access_token = os.environ.get(_ENV_ACCESS_TOKEN)
+                if access_token:
+                    kwargs["access_token"] = access_token
+                pydevd.settrace(**kwargs)
+            finally:
+                pydevd.start_server = _orig
+        finally:
+            # Harmless if _patched() already closed it after accept().
+            server_sock.close()
+
+    @staticmethod
+    def _task_connect(task_key, wait):
+        """Local mode: connect back to the parent adapter."""
+        import debugpy
+
+        adapter_host = os.environ[_ENV_ADAPTER_HOST]
+        adapter_port = int(os.environ[_ENV_ADAPTER_PORT])
+        parent_pid = int(os.environ[_ENV_PARENT_PID])
+
+        debugpy.connect(
+            (adapter_host, adapter_port),
+            parent_session_pid=parent_pid,
+        )
+
+        if wait:
+            debugpy.wait_for_client()
+
+    def runtime_finished(self, exception):
+        """Reset the class-level state set up by runtime_init()."""
+        self.__class__._adapter_info = None
+        self.__class__._banner_printed = False
+        self.__class__._is_remote = False
+        self.__class__._callback_host = None
+        self.__class__._callback_port = None
diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py
index e15f7b06cb9..dc0e00d0a97 100644
--- a/metaflow/plugins/kubernetes/kubernetes_cli.py
+++ b/metaflow/plugins/kubernetes/kubernetes_cli.py
@@ -223,6 +223,11 @@ def echo(msg, stream="stderr", job_id=None, **kwargs):
         kwargs["input_paths"] = "".join("${%s}" % s for s in split_vars.keys())
         env.update(split_vars)
 
+    # Forward debugger env vars to the remote container.
+    for key, val in os.environ.items():
+        if key.startswith("METAFLOW_DEBUGPY_"):
+            env[key] = val
+
     if num_parallel is not None and num_parallel <= 1:
         raise KubernetesException(
             "Using @parallel with `num_parallel` <= 1 is not supported with "