|
50 | 50 | operation_name = None # global used in formatted logging |
51 | 51 |
|
52 | 52 |
|
| 53 | +def set_dist_train_config(rank, nranks, step_name, port=9888): |
| 54 | + """ |
| 55 | + Set distributed training envs for general uses. |
| 56 | + For Tensorflow: TF_CONFIG is configured. |
| 57 | + For Pytorch: MASTER_ADDR and MASTER_PORT is configured. |
| 58 | + For general use cases: NRANKS and RANK is configured. |
| 59 | +
|
| 60 | + TODO: this function is Argo specific, should add Tekton support. |
| 61 | + """ |
| 62 | + from kubernetes import client, config |
| 63 | + |
| 64 | + wf_id = os.getenv("WORKFLOW_ID") |
| 65 | + ns = os.getenv("KFP_NAMESPACE") |
| 66 | + if not wf_id or not ns: |
| 67 | + raise ValueError("WORKFLOW_ID and KFP_NAMESPACE env must be set in the workflow pod!") |
| 68 | + |
| 69 | + config.load_incluster_config() |
| 70 | + api = client.CustomObjectsApi() |
| 71 | + |
| 72 | + worker_started = 0 |
| 73 | + while worker_started != nranks: |
| 74 | + resource = api.get_namespaced_custom_object( |
| 75 | + group="argoproj.io", |
| 76 | + version="v1alpha1", |
| 77 | + name=wf_id, |
| 78 | + namespace=ns, |
| 79 | + plural="workflows", |
| 80 | + ) |
| 81 | + if not resource.get("status"): |
| 82 | + time.sleep(2) |
| 83 | + continue |
| 84 | + if not resource["status"].get("nodes"): |
| 85 | + time.sleep(2) |
| 86 | + continue |
| 87 | + |
| 88 | + nodes = resource["status"]["nodes"] |
| 89 | + workers_spec = [] |
| 90 | + for nk in nodes: |
| 91 | + node_info = nodes[nk] |
| 92 | + OpUtil.log_operation_info( |
| 93 | + "kfpdist: searching for {}, curr node: {}, templateName: {}, type: {}".format( |
| 94 | + step_name, nk, node_info["templateName"], node_info["type"] |
| 95 | + ) |
| 96 | + ) |
| 97 | + if node_info["templateName"] == step_name and node_info["type"] == "Pod": |
| 98 | + podid = node_info["id"] |
| 99 | + for input_param in node_info["inputs"]["parameters"]: |
| 100 | + if input_param["name"].find("loop-item") >= 0: |
| 101 | + # FIXME: argo parameter with "loop-item" is the rank. |
| 102 | + curr_rank = int(input_param["value"]) |
| 103 | + break |
| 104 | + v1 = client.CoreV1Api() |
| 105 | + podinfo = v1.read_namespaced_pod(podid, ns) |
| 106 | + if podinfo.status.pod_ip: |
| 107 | + workers_spec.append((curr_rank, "%s:%d" % (podinfo.status.pod_ip, port))) |
| 108 | + worker_started = len(workers_spec) |
| 109 | + time.sleep(2) |
| 110 | + |
| 111 | + workers_spec.sort(key=lambda item: item[0]) |
| 112 | + workers_spec_list = [i[1] for i in workers_spec] |
| 113 | + # set TF_CONFIG env for tf dist train |
| 114 | + os.environ["TF_CONFIG"] = json.dumps( |
| 115 | + {"cluster": {"worker": workers_spec_list}, "task": {"type": "worker", "index": rank}} |
| 116 | + ) |
| 117 | + OpUtil.log_operation_info("Setting TF_CONFIG: %s" % os.environ["TF_CONFIG"]) |
| 118 | + os.environ["MASTER_ADDR"] = workers_spec[0][1].split(":")[0] |
| 119 | + os.environ["MASTER_PORT"] = workers_spec[0][1].split(":")[1] |
| 120 | + OpUtil.log_operation_info( |
| 121 | + "Setting MASTER_ADDR: {}, MASTER_PORT: {}".format(os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"]) |
| 122 | + ) |
| 123 | + OpUtil.log_operation_info("Setting RANK: {}, NRANKS: {}".format(os.environ["RANK"], os.environ["NRANKS"])) |
| 124 | + |
| 125 | + |
53 | 126 | class FileOpBase(ABC): |
54 | 127 | """Abstract base class for file-based operations""" |
55 | 128 |
|
@@ -724,6 +797,22 @@ def main(): |
724 | 797 | ) |
725 | 798 | # Setup packages and gather arguments |
726 | 799 | input_params = OpUtil.parse_arguments(sys.argv[1:]) |
| 800 | + |
| 801 | + if os.getenv("RANK"): |
| 802 | + op_name = os.getenv("ELYRA_OP_NAME") |
| 803 | + if not op_name: |
| 804 | + raise ValueError( |
| 805 | + "env ELYRA_OP_NAME is not set. please check whether elyra version is matching bootstrapper.py" |
| 806 | + ) |
| 807 | + |
| 808 | + # FIXME: operation name will be updated by kfp, replace these chars for matching. |
| 809 | + op_name = op_name.replace("_", "-") |
| 810 | + rank = int(os.getenv("RANK")) |
| 811 | + nranks = int(os.getenv("NRANKS")) |
| 812 | + if not nranks: |
| 813 | + raise ValueError("rank argument setted but no NRANKS env found!") |
| 814 | + set_dist_train_config(rank, nranks, op_name, port=9888) |
| 815 | + |
727 | 816 | OpUtil.log_operation_info("starting operation") |
728 | 817 | t0 = time.time() |
729 | 818 | OpUtil.package_install(user_volume_path=input_params.get("user-volume-path")) |
|
0 commit comments