Skip to content

Commit 36d216f

Browse files
sutaakarRHRolunFiona-Waters
authored
Add KFTO Distributed training notebook and scripts (#103)
* Add KFTO Distributed training notebook and scripts Signed-off-by: Karel Suta <[email protected]> * Apply suggestions from code review Co-authored-by: Fiona Waters <[email protected]> * Addressed feedback from PR * Adjust KFTO PyTorchJob waiting to be more interactive --------- Signed-off-by: Karel Suta <[email protected]> Co-authored-by: RHRolun <[email protected]> Co-authored-by: Fiona Waters <[email protected]>
1 parent c067995 commit 36d216f

File tree

7 files changed

+748
-3
lines changed

7 files changed

+748
-3
lines changed

9_distributed_training_kfto.ipynb

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Training the Fraud Detection model with Kubeflow Training Operator"
8+
]
9+
},
10+
{
11+
"cell_type": "markdown",
12+
"metadata": {},
13+
"source": [
14+
"The example fraud detection model is very small and quickly trained. However, for many large models, training requires multiple GPUs and often multiple machines. In this notebook, you learn how to train a model by using Kubeflow Training Operator on OpenShift AI to scale out the model training. You use the Training Operator SDK to create a PyTorchJob executing the provided training script."
15+
]
16+
},
17+
{
18+
"cell_type": "markdown",
19+
"metadata": {},
20+
"source": [
21+
"### Preparing Training Operator SDK\n",
22+
"\n",
23+
"Training operator SDK is not available by default on Tensorflow notebooks.Therefore it needs to be installed first."
24+
]
25+
},
26+
{
27+
"cell_type": "code",
28+
"execution_count": null,
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"%pip install -qqU kubeflow-training==1.9.2"
33+
]
34+
},
35+
{
36+
"cell_type": "markdown",
37+
"metadata": {},
38+
"source": [
39+
"### Preparing the data\n",
40+
"\n",
41+
"Normally, the training data for your model would be available in a shared location. For this example, the data is local. You must upload it to your object storage so that you can see how data loading from a shared data source works. Training data is downloaded via the training script and distributed among workers by DistributedSampler."
42+
]
43+
},
44+
{
45+
"cell_type": "code",
46+
"execution_count": null,
47+
"metadata": {
48+
"tags": []
49+
},
50+
"outputs": [],
51+
"source": [
52+
"import sys\n",
53+
"sys.path.append('./utils')\n",
54+
"\n",
55+
"import utils.s3\n",
56+
"\n",
57+
"utils.s3.upload_directory_to_s3(\"data\", \"data\")\n",
58+
"print(\"---\")\n",
59+
"utils.s3.list_objects(\"data\")"
60+
]
61+
},
62+
{
63+
"cell_type": "markdown",
64+
"metadata": {},
65+
"source": [
66+
"### Authenticate to the cluster by using the OpenShift console login\n",
67+
"\n",
68+
"Training Operator SDK requires authenticated access to the OpenShift cluster to create PyTorchJobs. The easiest way to get access details is through the OpenShift web console. \n",
69+
" \n",
70+
"\n",
71+
"1. To generate the command, select **Copy login command** from the username drop-down menu at the top right of the web console.\n",
72+
"\n",
73+
" <figure>\n",
74+
" <img src=\"./assets/copy-login.png\" alt=\"copy login\" >\n",
75+
" <figure/>\n",
76+
"\n",
77+
"2. Click **Display token**.\n",
78+
"\n",
79+
"3. Below **Log in with this token**, take note of the parameters for token and server.\n",
80+
" For example:\n",
81+
" ```\n",
82+
" oc login --token=sha256~LongString --server=https://api.your-cluster.domain.com:6443\n",
83+
" ``` \n",
84+
" - token: `sha256~LongString`\n",
85+
" - server: `https://api.your-cluster.domain.com:6443`\n",
86+
" \n",
87+
"4. In the following code cell replace the token and server values with the values that you noted in Step 3.\n",
88+
" For example:\n",
89+
" ```\n",
90+
" api_server = \"https://api.your-cluster.domain.com:6443\"\n",
91+
" token = \"sha256~LongString\"\n",
92+
" ```\n"
93+
]
94+
},
95+
{
96+
"cell_type": "code",
97+
"execution_count": null,
98+
"metadata": {
99+
"tags": []
100+
},
101+
"outputs": [],
102+
"source": [
103+
"from kubernetes import client\n",
104+
"\n",
105+
"api_server = \"https://XXXX\"\n",
106+
"token = \"sha256~XXXX\"\n",
107+
"\n",
108+
"configuration = client.Configuration()\n",
109+
"configuration.host = api_server\n",
110+
"configuration.api_key = {\"authorization\": f\"Bearer {token}\"}\n",
111+
"# Un-comment if your cluster API server uses a self-signed certificate or an un-trusted CA\n",
112+
"#configuration.verify_ssl = False"
113+
]
114+
},
115+
{
116+
"cell_type": "markdown",
117+
"metadata": {
118+
"tags": []
119+
},
120+
"source": [
121+
"## Running the distributed training"
122+
]
123+
},
124+
{
125+
"cell_type": "markdown",
126+
"metadata": {},
127+
"source": [
128+
"### Initialize Training client\n",
129+
"\n",
130+
"Initialize Training client using provided user credentials."
131+
]
132+
},
133+
{
134+
"cell_type": "code",
135+
"execution_count": null,
136+
"metadata": {},
137+
"outputs": [],
138+
"source": [
139+
"from kubeflow.training import TrainingClient\n",
140+
"\n",
141+
"client = TrainingClient(client_configuration=configuration)"
142+
]
143+
},
144+
{
145+
"cell_type": "markdown",
146+
"metadata": {},
147+
"source": [
148+
"### Create PyTorchJob\n",
149+
"\n",
150+
"Submit PyTorchJob using Training Operator SDK client.\n",
151+
"\n",
152+
"Training script is imported from `kfto-scripts` folder.\n",
153+
"\n",
154+
"Training script loads and distributes training dataset among nodes, performs distributed training, evaluation using test dataset, exports the trained model to onnx format and uploads it to the S3 bucket specified in provided connection.\n",
155+
"\n",
156+
"Important note - If Kueue component is enabled in RHOAI then you must create all Kueue related resources (ResourceFlavor, ClusterQueue and LocalQueue) and provide LocalQueue name in the script below, also uncomment label declaration in create_job function."
157+
]
158+
},
159+
{
160+
"cell_type": "code",
161+
"execution_count": null,
162+
"metadata": {
163+
"tags": []
164+
},
165+
"outputs": [],
166+
"source": [
167+
"import sys\n",
168+
"import os\n",
169+
"sys.path.append(\"./kfto-scripts\") # needed to make training function available in the notebook\n",
170+
"from train_pytorch_cpu import train_func\n",
171+
"from kubernetes.client import (\n",
172+
" V1EnvVar,\n",
173+
" V1EnvVarSource,\n",
174+
" V1SecretKeySelector\n",
175+
")\n",
176+
"\n",
177+
"# Job name serves as unique identifier to retrieve job related informations using SDK\n",
178+
"job_name = \"fraud-detection\"\n",
179+
"\n",
180+
"# Specifies Kueue LocalQueue name.\n",
181+
"# If Kueue component is enabled then you must create all Kueue related resources (ResourceFlavor, ClusterQueue and LocalQueue) and provide LocalQueue name here.\n",
182+
"local_queue_name = \"local-queue\"\n",
183+
"\n",
184+
"client.create_job(\n",
185+
" job_kind=\"PyTorchJob\",\n",
186+
" name=job_name,\n",
187+
" train_func=train_func,\n",
188+
" num_workers=2,\n",
189+
" num_procs_per_worker=\"1\",\n",
190+
" resources_per_worker={\n",
191+
" \"memory\": \"4Gi\",\n",
192+
" \"cpu\": 1,\n",
193+
" },\n",
194+
" base_image=\"quay.io/modh/training:py311-cuda124-torch251\",\n",
195+
" # Uncomment the following line to add the queue-name label if Kueue component is enabled in RHOAI and all Kueue related resources are created. Replace `local_queue_name` with the name of your LocalQueue\n",
196+
"# labels={\"kueue.x-k8s.io/queue-name\": local_queue_name},\n",
197+
" env_vars=[\n",
198+
" V1EnvVar(name=\"AWS_ACCESS_KEY_ID\", value=os.environ.get(\"AWS_ACCESS_KEY_ID\")),\n",
199+
" V1EnvVar(name=\"AWS_S3_BUCKET\", value=os.environ.get(\"AWS_S3_BUCKET\")),\n",
200+
" V1EnvVar(name=\"AWS_S3_ENDPOINT\", value=os.environ.get(\"AWS_S3_ENDPOINT\")),\n",
201+
" V1EnvVar(name=\"AWS_SECRET_ACCESS_KEY\", value=os.environ.get(\"AWS_SECRET_ACCESS_KEY\")),\n",
202+
" ],\n",
203+
" packages_to_install=[\n",
204+
" \"s3fs\",\n",
205+
" \"boto3\",\n",
206+
" \"scikit-learn\",\n",
207+
" \"onnx\",\n",
208+
" ],\n",
209+
")"
210+
]
211+
},
212+
{
213+
"cell_type": "markdown",
214+
"metadata": {
215+
"tags": []
216+
},
217+
"source": [
218+
"### Query important job information"
219+
]
220+
},
221+
{
222+
"cell_type": "code",
223+
"execution_count": null,
224+
"metadata": {
225+
"tags": []
226+
},
227+
"outputs": [],
228+
"source": [
229+
"import time\n",
230+
"\n",
231+
"\n",
232+
"# Wait until the job finishes\n",
233+
"print(f\"PyTorchJob '{job_name}' is running.\", end='')\n",
234+
"while True:\n",
235+
" try:\n",
236+
" if client.is_job_running(name=job_name):\n",
237+
" print(\".\", end='')\n",
238+
" elif client.is_job_succeeded(name=job_name):\n",
239+
" print(\".\")\n",
240+
" print([x.message for x in client.get_job_conditions(name=job_name) if x.type == \"Succeeded\"][0])\n",
241+
" break\n",
242+
" elif client.is_job_failed(name=job_name):\n",
243+
" print(\".\")\n",
244+
" print([x.message for x in client.get_job_conditions(name=job_name) if x.type == \"Failed\"][0])\n",
245+
" break\n",
246+
" else:\n",
247+
" print(f\"PyTorchJob '{job_name}' status not available or no conditions found.\")\n",
248+
" break\n",
249+
"\n",
250+
" except Exception as e:\n",
251+
" print(f\"Error getting PyTorchJob status: {e}.\")\n",
252+
"\n",
253+
" time.sleep(3)"
254+
]
255+
},
256+
{
257+
"cell_type": "code",
258+
"execution_count": null,
259+
"metadata": {},
260+
"outputs": [],
261+
"source": [
262+
"# Get the job logs\n",
263+
"print(client.get_job_logs(name=job_name)[0][\"fraud-detection-master-0\"])"
264+
]
265+
},
266+
{
267+
"cell_type": "markdown",
268+
"metadata": {
269+
"tags": []
270+
},
271+
"source": [
272+
"### Delete jobs\n",
273+
"\n",
274+
"When finished you can delete the PyTorchJob."
275+
]
276+
},
277+
{
278+
"cell_type": "code",
279+
"execution_count": null,
280+
"metadata": {
281+
"tags": []
282+
},
283+
"outputs": [],
284+
"source": [
285+
"client.delete_job(name=job_name)"
286+
]
287+
}
288+
],
289+
"metadata": {
290+
"kernelspec": {
291+
"display_name": "Python 3.11",
292+
"language": "python",
293+
"name": "python3"
294+
},
295+
"language_info": {
296+
"codemirror_mode": {
297+
"name": "ipython",
298+
"version": 3
299+
},
300+
"file_extension": ".py",
301+
"mimetype": "text/x-python",
302+
"name": "python",
303+
"nbconvert_exporter": "python",
304+
"pygments_lexer": "ipython3",
305+
"version": "3.11.11"
306+
}
307+
},
308+
"nbformat": 4,
309+
"nbformat_minor": 4
310+
}

0 commit comments

Comments
 (0)