Skip to content

Commit 8291b77

Browse files
committed
add support for default task resource overrides
1 parent 8abecfe commit 8291b77

File tree

11 files changed

+311
-3
lines changed

11 files changed

+311
-3
lines changed

flytekit/clis/sdk_in_container/register.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from flytekit.configuration import ImageConfig
1616
from flytekit.configuration.default_images import DefaultImages
1717
from flytekit.constants import CopyFileDetection
18-
from flytekit.interaction.click_types import key_value_callback
18+
from flytekit.core.resources import ResourceSpec
19+
from flytekit.interaction.click_types import key_value_callback, resource_spec_callback
1920
from flytekit.loggers import logger
2021
from flytekit.tools import repo
2122

@@ -134,6 +135,15 @@
134135
callback=key_value_callback,
135136
help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`",
136137
)
138+
@click.option(
139+
"--default-resources",
140+
required=False,
141+
type=str,
142+
callback=resource_spec_callback,
143+
help="Override default task resource requests and limits for tasks that have no statically defined resource request and limit. "
144+
"""Example usage: --default-resources 'cpu=1;mem=2Gi;gpu=1' for requests only or """
145+
"""--default-resources 'cpu=(0.5,1);mem=(2Gi,4Gi);gpu=1' to specify both requests and limits""",
146+
)
137147
@click.option(
138148
"--skip-errors",
139149
"--skip-error",
@@ -161,6 +171,7 @@ def register(
161171
dry_run: bool,
162172
activate_launchplans: bool,
163173
env: typing.Optional[typing.Dict[str, str]],
174+
default_resources: typing.Optional[ResourceSpec],
164175
skip_errors: bool,
165176
):
166177
"""
@@ -225,6 +236,7 @@ def register(
225236
package_or_module=package_or_module,
226237
remote=remote,
227238
env=env,
239+
default_resources=default_resources,
228240
dry_run=dry_run,
229241
activate_launchplans=activate_launchplans,
230242
skip_errors=skip_errors,

flytekit/clis/sdk_in_container/run.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from flytekit.core.artifact import ArtifactQuery
4444
from flytekit.core.base_task import PythonTask
4545
from flytekit.core.data_persistence import FileAccessProvider
46+
from flytekit.core.resources import ResourceSpec
4647
from flytekit.core.type_engine import TypeEngine
4748
from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
4849
from flytekit.exceptions.system import FlyteSystemException
@@ -51,6 +52,7 @@
5152
FlyteLiteralConverter,
5253
key_value_callback,
5354
labels_callback,
55+
resource_spec_callback,
5456
)
5557
from flytekit.interaction.string_literals import literal_string_repr
5658
from flytekit.loggers import logger
@@ -197,6 +199,18 @@ class RunLevelParams(PyFlyteParams):
197199
help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`",
198200
)
199201
)
202+
default_resources: typing.Optional[ResourceSpec] = make_click_option_field(
203+
click.Option(
204+
param_decls=["--default-resources"],
205+
required=False,
206+
show_default=True,
207+
type=str,
208+
callback=resource_spec_callback,
209+
help="During fast registration, will override default task resource requests and limits for tasks that have no statically defined resource request and limit. "
210+
"""Example usage: --default-resources 'cpu=1;mem=2Gi;gpu=1' for requests only or """
211+
"""--default-resources 'cpu=(0.5,1);mem=(2Gi,4Gi);gpu=1' to specify both requests and limits""",
212+
)
213+
)
200214
tags: typing.List[str] = make_click_option_field(
201215
click.Option(
202216
param_decls=["--tags", "--tag"],
@@ -745,6 +759,7 @@ def _run(*args, **kwargs):
745759
source_path=run_level_params.computed_params.project_root,
746760
module_name=run_level_params.computed_params.module,
747761
fast_package_options=fast_package_options,
762+
default_resources=run_level_params.default_resources,
748763
)
749764

750765
run_remote(

flytekit/configuration/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@
148148
from flytekit.configuration import internal as _internal
149149
from flytekit.configuration.default_images import DefaultImages
150150
from flytekit.configuration.file import ConfigEntry, ConfigFile, get_config_file, read_file_if_exists, set_if_exists
151+
from flytekit.core.resources import ResourceSpec
151152
from flytekit.image_spec import ImageSpec
152153
from flytekit.image_spec.image_spec import ImageBuildEngine
153154
from flytekit.loggers import logger
@@ -824,6 +825,8 @@ class SerializationSettings(DataClassJsonMixin):
824825
version (str): The version (if any) with which to register entities under.
825826
image_config (ImageConfig): The image config used to define task container images.
826827
env (Optional[Dict[str, str]]): Environment variables injected into task container definitions.
828+
default_resources (Optional[ResourceSpec]): The resources to request for the task - this is useful
829+
if users need to override the default resource spec of an entity at registration time.
827830
flytekit_virtualenv_root (Optional[str]): During out of container serialize the absolute path of the flytekit
828831
virtualenv at serialization time won't match the in-container value at execution time. This optional value
829832
is used to provide the in-container virtualenv path
@@ -842,6 +845,7 @@ class SerializationSettings(DataClassJsonMixin):
842845
domain: typing.Optional[str] = None
843846
version: typing.Optional[str] = None
844847
env: Optional[Dict[str, str]] = None
848+
default_resources: Optional[ResourceSpec] = None
845849
git_repo: Optional[str] = None
846850
python_interpreter: str = DEFAULT_RUNTIME_PYTHON_INTERPRETER
847851
flytekit_virtualenv_root: Optional[str] = None
@@ -916,6 +920,7 @@ def new_builder(self) -> Builder:
916920
version=self.version,
917921
image_config=self.image_config,
918922
env=self.env.copy() if self.env else None,
923+
default_resources=self.default_resources,
919924
git_repo=self.git_repo,
920925
flytekit_virtualenv_root=self.flytekit_virtualenv_root,
921926
python_interpreter=self.python_interpreter,
@@ -967,6 +972,7 @@ class Builder(object):
967972
version: str
968973
image_config: ImageConfig
969974
env: Optional[Dict[str, str]] = None
975+
default_resources: Optional[ResourceSpec] = None
970976
git_repo: Optional[str] = None
971977
flytekit_virtualenv_root: Optional[str] = None
972978
python_interpreter: Optional[str] = None
@@ -984,6 +990,7 @@ def build(self) -> SerializationSettings:
984990
version=self.version,
985991
image_config=self.image_config,
986992
env=self.env,
993+
default_resources=self.default_resources,
987994
git_repo=self.git_repo,
988995
flytekit_virtualenv_root=self.flytekit_virtualenv_root,
989996
python_interpreter=self.python_interpreter,

flytekit/core/python_auto_container.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,17 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain
231231
for elem in (settings.env, self.environment):
232232
if elem:
233233
env.update(elem)
234+
# Override the task's resource spec if it was not set statically in the task definition
235+
236+
def _resources_unspecified(resources: ResourceSpec) -> bool:
237+
return resources == ResourceSpec(
238+
requests=Resources(),
239+
limits=Resources(),
240+
)
241+
242+
if isinstance(settings.default_resources, ResourceSpec) and _resources_unspecified(self.resources):
243+
self._resources = settings.default_resources
244+
234245
return _get_container_definition(
235246
image=self.get_image(settings),
236247
resource_spec=self.resources,

flytekit/interaction/click_types.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from flytekit import BlobType, FlyteContext, Literal, LiteralType, StructuredDataset
2121
from flytekit.core.artifact import ArtifactQuery
2222
from flytekit.core.data_persistence import FileAccessProvider
23+
from flytekit.core.resources import Resources, ResourceSpec
2324
from flytekit.core.type_engine import TypeEngine
2425
from flytekit.models.types import SimpleType
2526
from flytekit.remote.remote_fs import FlytePathResolver
@@ -81,6 +82,40 @@ def labels_callback(_: typing.Any, param: str, values: typing.List[str]) -> typi
8182
return result
8283

8384

85+
def resource_spec_callback(_: typing.Any, param: str, value: typing.Optional[str]) -> typing.Optional[ResourceSpec]:
86+
"""
87+
Callback for click to parse a resource spec.
88+
"""
89+
if not value:
90+
return None
91+
92+
def _extract_pair(s: str) -> typing.Optional[typing.Tuple[str, str]]:
93+
"""Can extract the pair of values "0.5" and "1" from the string '(0.5,1)'"""
94+
vals = s.strip("() ").split(",")
95+
if len(vals) != 2:
96+
return None
97+
return vals[0].strip(), vals[1].strip()
98+
99+
items = value.split(";")
100+
_allowed_keys = Resources.__annotations__.keys()
101+
result = {}
102+
for item in items:
103+
kv_split = item.split("=")
104+
if len(kv_split) != 2:
105+
raise click.BadParameter(
106+
f"Expected semicolon separated key-value pairs of the form 'key1=value1;key2=value2;...', got '{item}'"
107+
)
108+
k = kv_split[0].strip()
109+
v = kv_split[1].strip()
110+
if k not in _allowed_keys:
111+
raise click.BadParameter(f"Expected key to be one of {list(_allowed_keys)}, got '{k}'")
112+
if k in result:
113+
raise click.BadParameter(f"Expected unique keys {list(_allowed_keys)}, got '{k}' multiple times")
114+
result[k.strip()] = _extract_pair(v) or v
115+
116+
return ResourceSpec.from_multiple_resource(Resources(**result))
117+
118+
84119
class DirParamType(click.ParamType):
85120
name = "directory path"
86121

flytekit/remote/remote.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
)
5757
from flytekit.core.python_function_task import PythonFunctionTask
5858
from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec
59+
from flytekit.core.resources import ResourceSpec
5960
from flytekit.core.task import ReferenceTask
6061
from flytekit.core.tracker import extract_task_module
6162
from flytekit.core.type_engine import LiteralsResolver, TypeEngine, strict_type_hint_matching
@@ -1310,6 +1311,7 @@ def register_script(
13101311
source_path: typing.Optional[str] = None,
13111312
module_name: typing.Optional[str] = None,
13121313
envs: typing.Optional[typing.Dict[str, str]] = None,
1314+
default_resources: typing.Optional[ResourceSpec] = None,
13131315
fast_package_options: typing.Optional[FastPackageOptions] = None,
13141316
) -> typing.Union[FlyteWorkflow, FlyteTask, FlyteLaunchPlan, ReferenceEntity]:
13151317
"""
@@ -1326,6 +1328,7 @@ def register_script(
13261328
:param source_path: The root of the project path
13271329
:param module_name: the name of the module
13281330
:param envs: Environment variables to be passed to the serialization
1331+
:param default_resources: Default resources to be passed to the serialization. These override the resource spec for any tasks that have no statically defined resource requests and limits.
13291332
:param fast_package_options: Options to customize copy_all behavior, ignored when copy_all is False.
13301333
:return:
13311334
"""
@@ -1364,6 +1367,7 @@ def register_script(
13641367
image_config=image_config,
13651368
git_repo=_get_git_repo_url(source_path),
13661369
env=envs,
1370+
default_resources=default_resources,
13671371
fast_serialization_settings=FastSerializationSettings(
13681372
enabled=True,
13691373
destination_dir=destination_dir,

flytekit/tools/repo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings
1313
from flytekit.constants import CopyFileDetection
1414
from flytekit.core.context_manager import FlyteContextManager
15+
from flytekit.core.resources import ResourceSpec
1516
from flytekit.loggers import logger
1617
from flytekit.models import launch_plan, task
1718
from flytekit.models.core.identifier import Identifier
@@ -251,6 +252,7 @@ def register(
251252
remote: FlyteRemote,
252253
copy_style: CopyFileDetection,
253254
env: typing.Optional[typing.Dict[str, str]],
255+
default_resources: typing.Optional[ResourceSpec],
254256
dry_run: bool = False,
255257
activate_launchplans: bool = False,
256258
skip_errors: bool = False,
@@ -273,6 +275,7 @@ def register(
273275
image_config=image_config,
274276
fast_serialization_settings=None, # should probably add incomplete fast settings
275277
env=env,
278+
default_resources=default_resources,
276279
)
277280

278281
if not version and copy_style == CopyFileDetection.NO_COPY:

tests/flytekit/integration/remote/test_remote.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
import pytest
1818
from unittest import mock
1919
from dataclasses import dataclass
20+
import random
21+
import string
2022

2123
from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase, task, workflow
2224
from flytekit.configuration import Config, ImageConfig, SerializationSettings
2325
from flytekit.core.launch_plan import reference_launch_plan
2426
from flytekit.core.task import reference_task
2527
from flytekit.core.workflow import reference_workflow
28+
from flytekit.models import task as task_models
2629
from flytekit.exceptions.user import FlyteAssertion, FlyteEntityNotExistException
2730
from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task
2831
from flytekit.remote.remote import FlyteRemote
@@ -1170,3 +1173,98 @@ def test_register_wf_twice(register):
11701173
]
11711174
)
11721175
assert out.returncode == 0
1176+
1177+
1178+
def test_register_wf_with_default_resources_override(register):
1179+
# Save the version here to retrieve the created task later
1180+
version = str(uuid.uuid4())
1181+
# Register the workflow with overridden default resources
1182+
out = subprocess.run(
1183+
[
1184+
"pyflyte",
1185+
"--verbose",
1186+
"-c",
1187+
CONFIG,
1188+
"register",
1189+
"--default-resources",
1190+
"cpu=1300m;mem=1100Mi",
1191+
"--image",
1192+
IMAGE,
1193+
"--project",
1194+
PROJECT,
1195+
"--domain",
1196+
DOMAIN,
1197+
"--version",
1198+
version,
1199+
MODULE_PATH / "hello_world.py",
1200+
]
1201+
)
1202+
assert out.returncode == 0
1203+
1204+
# Retrieve the created task
1205+
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
1206+
task = remote.fetch_task(name="basic.hello_world.say_hello", version=version)
1207+
assert task.template.container is not None
1208+
assert task.template.container.resources == task_models.Resources(
1209+
requests=[
1210+
task_models.Resources.ResourceEntry(
1211+
name=task_models.Resources.ResourceName.CPU,
1212+
value="1300m",
1213+
),
1214+
task_models.Resources.ResourceEntry(
1215+
name=task_models.Resources.ResourceName.MEMORY,
1216+
value="1100Mi",
1217+
),
1218+
],
1219+
limits=[],
1220+
)
1221+
1222+
1223+
def test_run_wf_with_default_resources_override(register):
1224+
# Save the execution id here to retrieve the created execution later
1225+
prefix = random.choice(string.ascii_lowercase)
1226+
short_random_part = uuid.uuid4().hex[:8]
1227+
execution_id = f"{prefix}{short_random_part}"
1228+
# Register the workflow with overridden default resources
1229+
out = subprocess.run(
1230+
[
1231+
"pyflyte",
1232+
"--verbose",
1233+
"-c",
1234+
CONFIG,
1235+
"run",
1236+
"--remote",
1237+
"--default-resources",
1238+
"cpu=500m;mem=1Gi",
1239+
"--project",
1240+
PROJECT,
1241+
"--domain",
1242+
DOMAIN,
1243+
"--name",
1244+
execution_id,
1245+
MODULE_PATH / "hello_world.py",
1246+
"my_wf"
1247+
]
1248+
)
1249+
assert out.returncode == 0
1250+
1251+
# Retrieve the created task
1252+
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
1253+
execution = remote.fetch_execution(name=execution_id)
1254+
execution = remote.wait(execution=execution)
1255+
version = execution.spec.launch_plan.version
1256+
task = remote.fetch_task(name="basic.hello_world.say_hello", version=version)
1257+
assert task.template.container is not None
1258+
assert task.template.container.resources == task_models.Resources(
1259+
requests=[
1260+
task_models.Resources.ResourceEntry(
1261+
name=task_models.Resources.ResourceName.CPU,
1262+
value="500m",
1263+
),
1264+
task_models.Resources.ResourceEntry(
1265+
name=task_models.Resources.ResourceName.MEMORY,
1266+
value="1Gi",
1267+
),
1268+
],
1269+
limits=[],
1270+
)

tests/flytekit/unit/core/test_context_manager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
)
2020
from flytekit.core import mock_stats, context_manager
2121
from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager, SecretsManager
22+
from flytekit.core.resources import ResourceSpec, Resources
2223
from flytekit.models.core import identifier as id_models
2324

2425

@@ -301,6 +302,7 @@ def test_serialization_settings_transport():
301302
domain="domain",
302303
version="version",
303304
env={"hello": "blah"},
305+
default_resources=ResourceSpec(requests=Resources(cpu="1", mem="2Gi"), limits=Resources(cpu="1", mem="2Gi")),
304306
image_config=ImageConfig(
305307
default_image=default_img,
306308
images=[default_img],
@@ -322,7 +324,7 @@ def test_serialization_settings_transport():
322324
ss = SerializationSettings.from_transport(tp)
323325
assert ss is not None
324326
assert ss == serialization_settings
325-
assert len(tp) == 408
327+
assert len(tp) == 480
326328

327329

328330
def test_exec_params():

0 commit comments

Comments
 (0)