Skip to content

Commit 1a7133a

Browse files
authored
Add ability to specify default resource requests/limits for tasks via pyflyte run and pyflyte register (#3229)
Signed-off-by: redartera <[email protected]>
1 parent 989eb67 commit 1a7133a

File tree

11 files changed

+336
-3
lines changed

11 files changed

+336
-3
lines changed

flytekit/clis/sdk_in_container/register.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from flytekit.configuration import ImageConfig
1616
from flytekit.configuration.default_images import DefaultImages
1717
from flytekit.constants import CopyFileDetection
18-
from flytekit.interaction.click_types import key_value_callback
18+
from flytekit.core.resources import Resources, ResourceSpec
19+
from flytekit.interaction.click_types import key_value_callback, resource_callback
1920
from flytekit.loggers import logger
2021
from flytekit.tools import repo
2122

@@ -134,6 +135,22 @@
134135
callback=key_value_callback,
135136
help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`",
136137
)
138+
@click.option(
139+
"--resource-requests",
140+
required=False,
141+
type=str,
142+
callback=resource_callback,
143+
help="Override default task resource requests for tasks that have no statically defined resource requests in their task decorator. "
144+
"Example usage: --resource-requests 'cpu=1,mem=2Gi,gpu=1'",
145+
)
146+
@click.option(
147+
"--resource-limits",
148+
required=False,
149+
type=str,
150+
callback=resource_callback,
151+
help="Override default task resource limits for tasks that have no statically defined resource limits in their task decorator. "
152+
"Example usage: --resource-limits 'cpu=1,mem=2Gi,gpu=1'",
153+
)
137154
@click.option(
138155
"--skip-errors",
139156
"--skip-error",
@@ -161,6 +178,8 @@ def register(
161178
dry_run: bool,
162179
activate_launchplans: bool,
163180
env: typing.Optional[typing.Dict[str, str]],
181+
resource_requests: typing.Optional[Resources],
182+
resource_limits: typing.Optional[Resources],
164183
skip_errors: bool,
165184
):
166185
"""
@@ -225,6 +244,9 @@ def register(
225244
package_or_module=package_or_module,
226245
remote=remote,
227246
env=env,
247+
default_resources=ResourceSpec(
248+
requests=resource_requests or Resources(), limits=resource_limits or Resources()
249+
),
228250
dry_run=dry_run,
229251
activate_launchplans=activate_launchplans,
230252
skip_errors=skip_errors,

flytekit/clis/sdk_in_container/run.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from flytekit.core.artifact import ArtifactQuery
4444
from flytekit.core.base_task import PythonTask
4545
from flytekit.core.data_persistence import FileAccessProvider
46+
from flytekit.core.resources import Resources, ResourceSpec
4647
from flytekit.core.type_engine import TypeEngine
4748
from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
4849
from flytekit.exceptions.system import FlyteSystemException
@@ -51,6 +52,7 @@
5152
FlyteLiteralConverter,
5253
key_value_callback,
5354
labels_callback,
55+
resource_callback,
5456
)
5557
from flytekit.interaction.string_literals import literal_string_repr
5658
from flytekit.loggers import logger
@@ -208,6 +210,28 @@ class RunLevelParams(PyFlyteParams):
208210
help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`",
209211
)
210212
)
213+
resource_requests: typing.Optional[Resources] = make_click_option_field(
214+
click.Option(
215+
param_decls=["--resource-requests"],
216+
required=False,
217+
show_default=True,
218+
type=str,
219+
callback=resource_callback,
220+
help="This overrides default task resource requests for tasks that have no statically defined resource requests in their task decorator. "
221+
"Example usage: --resource-requests 'cpu=1,mem=2Gi,gpu=1'",
222+
)
223+
)
224+
resource_limits: typing.Optional[Resources] = make_click_option_field(
225+
click.Option(
226+
param_decls=["--resource-limits"],
227+
required=False,
228+
show_default=True,
229+
type=str,
230+
callback=resource_callback,
231+
help="This overrides default task resource limits for tasks that have no statically defined resource limits in their task decorator. "
232+
"Example usage: --resource-limits 'cpu=1,mem=2Gi,gpu=1'",
233+
)
234+
)
211235
tags: typing.List[str] = make_click_option_field(
212236
click.Option(
213237
param_decls=["--tags", "--tag"],
@@ -756,6 +780,10 @@ def _run(*args, **kwargs):
756780
source_path=run_level_params.computed_params.project_root,
757781
module_name=run_level_params.computed_params.module,
758782
fast_package_options=fast_package_options,
783+
default_resources=ResourceSpec(
784+
requests=run_level_params.resource_requests or Resources(),
785+
limits=run_level_params.resource_limits or Resources(),
786+
),
759787
)
760788

761789
run_remote(

flytekit/configuration/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@
129129
from flytekit.configuration import internal as _internal
130130
from flytekit.configuration.default_images import DefaultImages
131131
from flytekit.configuration.file import ConfigEntry, ConfigFile, get_config_file, read_file_if_exists, set_if_exists
132+
from flytekit.core.resources import ResourceSpec
132133
from flytekit.image_spec import ImageSpec
133134
from flytekit.image_spec.image_spec import ImageBuildEngine
134135
from flytekit.loggers import logger
@@ -805,6 +806,8 @@ class SerializationSettings(DataClassJsonMixin):
805806
version (str): The version (if any) with which to register entities under.
806807
image_config (ImageConfig): The image config used to define task container images.
807808
env (Optional[Dict[str, str]]): Environment variables injected into task container definitions.
809+
default_resources (Optional[ResourceSpec]): The resources to request for the task - this is useful
810+
if users need to override the default resource spec of an entity at registration time.
808811
flytekit_virtualenv_root (Optional[str]): During out of container serialize the absolute path of the flytekit
809812
virtualenv at serialization time won't match the in-container value at execution time. This optional value
810813
is used to provide the in-container virtualenv path
@@ -823,6 +826,7 @@ class SerializationSettings(DataClassJsonMixin):
823826
domain: typing.Optional[str] = None
824827
version: typing.Optional[str] = None
825828
env: Optional[Dict[str, str]] = None
829+
default_resources: Optional[ResourceSpec] = None
826830
git_repo: Optional[str] = None
827831
python_interpreter: str = DEFAULT_RUNTIME_PYTHON_INTERPRETER
828832
flytekit_virtualenv_root: Optional[str] = None
@@ -897,6 +901,7 @@ def new_builder(self) -> Builder:
897901
version=self.version,
898902
image_config=self.image_config,
899903
env=self.env.copy() if self.env else None,
904+
default_resources=self.default_resources,
900905
git_repo=self.git_repo,
901906
flytekit_virtualenv_root=self.flytekit_virtualenv_root,
902907
python_interpreter=self.python_interpreter,
@@ -948,6 +953,7 @@ class Builder(object):
948953
version: str
949954
image_config: ImageConfig
950955
env: Optional[Dict[str, str]] = None
956+
default_resources: Optional[ResourceSpec] = None
951957
git_repo: Optional[str] = None
952958
flytekit_virtualenv_root: Optional[str] = None
953959
python_interpreter: Optional[str] = None
@@ -965,6 +971,7 @@ def build(self) -> SerializationSettings:
965971
version=self.version,
966972
image_config=self.image_config,
967973
env=self.env,
974+
default_resources=self.default_resources,
968975
git_repo=self.git_repo,
969976
flytekit_virtualenv_root=self.flytekit_virtualenv_root,
970977
python_interpreter=self.python_interpreter,

flytekit/core/python_auto_container.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,17 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain
233233
if elem:
234234
env.update(elem)
235235

236+
# Override the task's resource spec if it was not set statically in the task definition
237+
238+
def _resources_unspecified(resources: ResourceSpec) -> bool:
239+
return resources == ResourceSpec(
240+
requests=Resources(),
241+
limits=Resources(),
242+
)
243+
244+
if isinstance(settings.default_resources, ResourceSpec) and _resources_unspecified(self.resources):
245+
self._resources = settings.default_resources
246+
236247
# Add runtime dependencies into environment
237248
if isinstance(self.container_image, ImageSpec) and self.container_image.runtime_packages:
238249
runtime_packages = " ".join(self.container_image.runtime_packages)

flytekit/interaction/click_types.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from flytekit import BlobType, FlyteContext, Literal, LiteralType, StructuredDataset
2424
from flytekit.core.artifact import ArtifactQuery
2525
from flytekit.core.data_persistence import FileAccessProvider
26+
from flytekit.core.resources import Resources
2627
from flytekit.core.type_engine import TypeEngine
2728
from flytekit.models.types import SimpleType
2829
from flytekit.remote.remote_fs import FlytePathResolver
@@ -84,6 +85,33 @@ def labels_callback(_: typing.Any, param: str, values: typing.List[str]) -> typi
8485
return result
8586

8687

88+
def resource_callback(_: typing.Any, param: str, value: typing.Optional[str]) -> typing.Optional[Resources]:
89+
"""
90+
Click callback to parse resource strings like 'cpu=1,mem=2Gi' into a Resources object
91+
"""
92+
if not value:
93+
return None
94+
95+
items = value.split(",")
96+
_allowed_keys = Resources.__annotations__.keys()
97+
result = {}
98+
for item in items:
99+
kv_split = item.split("=")
100+
if len(kv_split) != 2:
101+
raise click.BadParameter(
102+
f"Expected comma separated key-value pairs of the form 'key1=value1,key2=value2,...', got '{item}'"
103+
)
104+
k = kv_split[0].strip()
105+
v = kv_split[1].strip()
106+
if k not in _allowed_keys:
107+
raise click.BadParameter(f"Expected key to be one of {list(_allowed_keys)}, but got '{k}'")
108+
if k in result:
109+
raise click.BadParameter(f"Expected unique keys {list(_allowed_keys)}, but got '{k}' multiple times")
110+
result[k] = v
111+
112+
return Resources(**result)
113+
114+
87115
class DirParamType(click.ParamType):
88116
name = "directory path"
89117

flytekit/remote/remote.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
)
5757
from flytekit.core.python_function_task import PythonFunctionTask
5858
from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec
59+
from flytekit.core.resources import ResourceSpec
5960
from flytekit.core.task import ReferenceTask
6061
from flytekit.core.tracker import extract_task_module
6162
from flytekit.core.type_engine import LiteralsResolver, TypeEngine, strict_type_hint_matching
@@ -1326,6 +1327,7 @@ def register_script(
13261327
source_path: typing.Optional[str] = None,
13271328
module_name: typing.Optional[str] = None,
13281329
envs: typing.Optional[typing.Dict[str, str]] = None,
1330+
default_resources: typing.Optional[ResourceSpec] = None,
13291331
fast_package_options: typing.Optional[FastPackageOptions] = None,
13301332
) -> typing.Union[FlyteWorkflow, FlyteTask, FlyteLaunchPlan, ReferenceEntity]:
13311333
"""
@@ -1342,6 +1344,7 @@ def register_script(
13421344
:param source_path: The root of the project path
13431345
:param module_name: the name of the module
13441346
:param envs: Environment variables to be passed to the serialization
1347+
:param default_resources: Default resources to be passed to the serialization. These override the resource spec for any tasks that have no statically defined resource requests and limits.
13451348
:param fast_package_options: Options to customize copy_all behavior, ignored when copy_all is False.
13461349
:return:
13471350
"""
@@ -1380,6 +1383,7 @@ def register_script(
13801383
image_config=image_config,
13811384
git_repo=_get_git_repo_url(source_path),
13821385
env=envs,
1386+
default_resources=default_resources,
13831387
fast_serialization_settings=FastSerializationSettings(
13841388
enabled=True,
13851389
destination_dir=destination_dir,

flytekit/tools/repo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from flytekit.constants import CopyFileDetection
1414
from flytekit.core.base_task import PythonTask
1515
from flytekit.core.context_manager import FlyteContextManager, FlyteEntities
16+
from flytekit.core.resources import ResourceSpec
1617
from flytekit.loggers import logger
1718
from flytekit.models import launch_plan, task
1819
from flytekit.models.core.identifier import Identifier
@@ -252,6 +253,7 @@ def register(
252253
remote: FlyteRemote,
253254
copy_style: CopyFileDetection,
254255
env: typing.Optional[typing.Dict[str, str]],
256+
default_resources: typing.Optional[ResourceSpec],
255257
dry_run: bool = False,
256258
activate_launchplans: bool = False,
257259
skip_errors: bool = False,
@@ -274,6 +276,7 @@ def register(
274276
image_config=image_config,
275277
fast_serialization_settings=None, # should probably add incomplete fast settings
276278
env=env,
279+
default_resources=default_resources,
277280
)
278281

279282
if not version and copy_style == CopyFileDetection.NO_COPY:

tests/flytekit/integration/remote/test_remote.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@
1616
import uuid
1717
import pytest
1818
from unittest import mock
19+
import random
20+
import string
1921
from dataclasses import asdict, dataclass
2022

2123
from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase, task, workflow
2224
from flytekit.configuration import Config, ImageConfig, SerializationSettings
2325
from flytekit.core.launch_plan import reference_launch_plan
2426
from flytekit.core.task import reference_task
2527
from flytekit.core.workflow import reference_workflow
28+
from flytekit.models import task as task_models
2629
from flytekit.exceptions.user import FlyteAssertion, FlyteEntityNotExistException
2730
from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task
2831
from flytekit.remote.remote import FlyteRemote
@@ -1252,3 +1255,106 @@ def test_register_wf_twice(register):
12521255
]
12531256
)
12541257
assert out.returncode == 0
1258+
1259+
1260+
def test_register_wf_with_resource_requests_override(register):
1261+
# Save the version here to retrieve the created task later
1262+
version = str(uuid.uuid4())
1263+
1264+
cpu = "1300m"
1265+
mem = "1100Mi"
1266+
1267+
# Register the workflow with overridden default resources
1268+
out = subprocess.run(
1269+
[
1270+
"pyflyte",
1271+
"--verbose",
1272+
"-c",
1273+
CONFIG,
1274+
"register",
1275+
"--resource-requests",
1276+
f"cpu={cpu},mem={mem}",
1277+
"--image",
1278+
IMAGE,
1279+
"--project",
1280+
PROJECT,
1281+
"--domain",
1282+
DOMAIN,
1283+
"--version",
1284+
version,
1285+
MODULE_PATH / "hello_world.py",
1286+
]
1287+
)
1288+
assert out.returncode == 0
1289+
1290+
# Retrieve the created task
1291+
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
1292+
task = remote.fetch_task(name="basic.hello_world.say_hello", version=version)
1293+
assert task.template.container is not None
1294+
assert task.template.container.resources == task_models.Resources(
1295+
requests=[
1296+
task_models.Resources.ResourceEntry(
1297+
name=task_models.Resources.ResourceName.CPU,
1298+
value=cpu,
1299+
),
1300+
task_models.Resources.ResourceEntry(
1301+
name=task_models.Resources.ResourceName.MEMORY,
1302+
value=mem,
1303+
),
1304+
],
1305+
limits=[],
1306+
)
1307+
1308+
1309+
def test_run_wf_with_resource_requests_override(register):
1310+
# Save the execution id here to retrieve the created execution later
1311+
prefix = random.choice(string.ascii_lowercase)
1312+
short_random_part = uuid.uuid4().hex[:8]
1313+
execution_id = f"{prefix}{short_random_part}"
1314+
1315+
cpu = "500m"
1316+
mem = "1Gi"
1317+
1318+
# Register the workflow with overridden default resources
1319+
out = subprocess.run(
1320+
[
1321+
"pyflyte",
1322+
"--verbose",
1323+
"-c",
1324+
CONFIG,
1325+
"run",
1326+
"--remote",
1327+
"--resource-requests",
1328+
f"cpu={cpu},mem={mem}",
1329+
"--project",
1330+
PROJECT,
1331+
"--domain",
1332+
DOMAIN,
1333+
"--name",
1334+
execution_id,
1335+
MODULE_PATH / "hello_world.py",
1336+
"my_wf"
1337+
]
1338+
)
1339+
assert out.returncode == 0
1340+
1341+
# Retrieve the created task
1342+
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
1343+
execution = remote.fetch_execution(name=execution_id)
1344+
execution = remote.wait(execution=execution)
1345+
version = execution.spec.launch_plan.version
1346+
task = remote.fetch_task(name="basic.hello_world.say_hello", version=version)
1347+
assert task.template.container is not None
1348+
assert task.template.container.resources == task_models.Resources(
1349+
requests=[
1350+
task_models.Resources.ResourceEntry(
1351+
name=task_models.Resources.ResourceName.CPU,
1352+
value=cpu,
1353+
),
1354+
task_models.Resources.ResourceEntry(
1355+
name=task_models.Resources.ResourceName.MEMORY,
1356+
value=mem,
1357+
),
1358+
],
1359+
limits=[],
1360+
)

0 commit comments

Comments
 (0)