|
17 | 17 | import pytest
|
18 | 18 | from unittest import mock
|
19 | 19 | from dataclasses import dataclass
|
| 20 | +import random |
| 21 | +import string |
20 | 22 |
|
21 | 23 | from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase, task, workflow
|
22 | 24 | from flytekit.configuration import Config, ImageConfig, SerializationSettings
|
23 | 25 | from flytekit.core.launch_plan import reference_launch_plan
|
24 | 26 | from flytekit.core.task import reference_task
|
25 | 27 | from flytekit.core.workflow import reference_workflow
|
| 28 | +from flytekit.models import task as task_models |
26 | 29 | from flytekit.exceptions.user import FlyteAssertion, FlyteEntityNotExistException
|
27 | 30 | from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task
|
28 | 31 | from flytekit.remote.remote import FlyteRemote
|
@@ -1170,3 +1173,98 @@ def test_register_wf_twice(register):
|
1170 | 1173 | ]
|
1171 | 1174 | )
|
1172 | 1175 | assert out.returncode == 0
|
| 1176 | + |
| 1177 | + |
| 1178 | +def test_register_wf_with_default_resources_override(register): |
| 1179 | + # Save the version here to retrieve the created task later |
| 1180 | + version = str(uuid.uuid4()) |
| 1181 | + # Register the workflow with overridden default resources |
| 1182 | + out = subprocess.run( |
| 1183 | + [ |
| 1184 | + "pyflyte", |
| 1185 | + "--verbose", |
| 1186 | + "-c", |
| 1187 | + CONFIG, |
| 1188 | + "register", |
| 1189 | + "--default-resources", |
| 1190 | + "cpu=1300m;mem=1100Mi", |
| 1191 | + "--image", |
| 1192 | + IMAGE, |
| 1193 | + "--project", |
| 1194 | + PROJECT, |
| 1195 | + "--domain", |
| 1196 | + DOMAIN, |
| 1197 | + "--version", |
| 1198 | + version, |
| 1199 | + MODULE_PATH / "hello_world.py", |
| 1200 | + ] |
| 1201 | + ) |
| 1202 | + assert out.returncode == 0 |
| 1203 | + |
| 1204 | + # Retrieve the created task |
| 1205 | + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) |
| 1206 | + task = remote.fetch_task(name="basic.hello_world.say_hello", version=version) |
| 1207 | + assert task.template.container is not None |
| 1208 | + assert task.template.container.resources == task_models.Resources( |
| 1209 | + requests=[ |
| 1210 | + task_models.Resources.ResourceEntry( |
| 1211 | + name=task_models.Resources.ResourceName.CPU, |
| 1212 | + value="1300m", |
| 1213 | + ), |
| 1214 | + task_models.Resources.ResourceEntry( |
| 1215 | + name=task_models.Resources.ResourceName.MEMORY, |
| 1216 | + value="1100Mi", |
| 1217 | + ), |
| 1218 | + ], |
| 1219 | + limits=[], |
| 1220 | + ) |
| 1221 | + |
| 1222 | + |
| 1223 | +def test_run_wf_with_default_resources_override(register): |
| 1224 | + # Save the execution id here to retrieve the created execution later |
| 1225 | + prefix = random.choice(string.ascii_lowercase) |
| 1226 | + short_random_part = uuid.uuid4().hex[:8] |
| 1227 | + execution_id = f"{prefix}{short_random_part}" |
| 1228 | + # Register the workflow with overridden default resources |
| 1229 | + out = subprocess.run( |
| 1230 | + [ |
| 1231 | + "pyflyte", |
| 1232 | + "--verbose", |
| 1233 | + "-c", |
| 1234 | + CONFIG, |
| 1235 | + "run", |
| 1236 | + "--remote", |
| 1237 | + "--default-resources", |
| 1238 | + "cpu=500m;mem=1Gi", |
| 1239 | + "--project", |
| 1240 | + PROJECT, |
| 1241 | + "--domain", |
| 1242 | + DOMAIN, |
| 1243 | + "--name", |
| 1244 | + execution_id, |
| 1245 | + MODULE_PATH / "hello_world.py", |
| 1246 | + "my_wf" |
| 1247 | + ] |
| 1248 | + ) |
| 1249 | + assert out.returncode == 0 |
| 1250 | + |
| 1251 | + # Retrieve the created task |
| 1252 | + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) |
| 1253 | + execution = remote.fetch_execution(name=execution_id) |
| 1254 | + execution = remote.wait(execution=execution) |
| 1255 | + version = execution.spec.launch_plan.version |
| 1256 | + task = remote.fetch_task(name="basic.hello_world.say_hello", version=version) |
| 1257 | + assert task.template.container is not None |
| 1258 | + assert task.template.container.resources == task_models.Resources( |
| 1259 | + requests=[ |
| 1260 | + task_models.Resources.ResourceEntry( |
| 1261 | + name=task_models.Resources.ResourceName.CPU, |
| 1262 | + value="500m", |
| 1263 | + ), |
| 1264 | + task_models.Resources.ResourceEntry( |
| 1265 | + name=task_models.Resources.ResourceName.MEMORY, |
| 1266 | + value="1Gi", |
| 1267 | + ), |
| 1268 | + ], |
| 1269 | + limits=[], |
| 1270 | + ) |
0 commit comments