Skip to content

Commit f1a4b33

Browse files
author
Bruno Grande
authored
Merge pull request #17 from Sage-Bionetworks-Workflows/bgrande/workflows-528/tweaks
[WORKFLOWS-528] Allow `LaunchInfo` values to be missing
2 parents 74f9c5c + e7fe85f commit f1a4b33

File tree

6 files changed

+99
-29
lines changed

6 files changed

+99
-29
lines changed

src/orca/services/nextflowtower/models.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import json
2-
from collections.abc import Collection
32
from dataclasses import field
43
from datetime import datetime
54
from enum import Enum
6-
from typing import Any, Optional
5+
from typing import Any, Iterable, Optional
76

87
from pydantic.dataclasses import dataclass
98
from typing_extensions import Self
109

11-
from orca.services.nextflowtower.utils import parse_datetime
10+
from orca.services.nextflowtower.utils import dedup, parse_datetime
1211

1312

1413
class TaskStatus(Enum):
@@ -91,9 +90,9 @@ def from_json(cls, response: dict[str, Any]) -> Self:
9190
class LaunchInfo:
9291
"""Nextflow Tower workflow launch specification."""
9392

94-
compute_env_id: str
95-
pipeline: str
96-
work_dir: str
93+
pipeline: Optional[str] = None
94+
compute_env_id: Optional[str] = None
95+
work_dir: Optional[str] = None
9796
revision: Optional[str] = None
9897
params: Optional[dict] = None
9998
nextflow_config: Optional[str] = None
@@ -104,18 +103,6 @@ class LaunchInfo:
104103
workspace_secrets: list[str] = field(default_factory=list)
105104
label_ids: list[int] = field(default_factory=list)
106105

107-
@staticmethod
108-
def dedup(items: Collection[str]) -> list[str]:
109-
"""Deduplicate items in a collection.
110-
111-
Args:
112-
items: Collection of items.
113-
114-
Returns:
115-
Deduplicated collection or None.
116-
"""
117-
return list(set(items))
118-
119106
def fill_in(self, attr: str, value: Any):
120107
"""Fill in any missing values.
121108
@@ -126,6 +113,35 @@ def fill_in(self, attr: str, value: Any):
126113
if not getattr(self, attr, None):
127114
setattr(self, attr, value)
128115

116+
def add_in(self, attr: str, values: Iterable[Any]):
    """Extend a list-valued attribute with additional values.

    The existing entries and the new ones are concatenated, passed
    through ``dedup`` and stored back on the attribute.

    Args:
        attr: Attribute name.
        values: New attribute values.

    Raises:
        ValueError: If the attribute's current value is not a list.
    """
    existing = getattr(self, attr)
    # Guard first: only list attributes can be extended.
    if not isinstance(existing, list):
        message = f"Attribute '{attr}' is not a list and cannot be extended."
        raise ValueError(message)
    merged = dedup([*existing, *values])
    setattr(self, attr, merged)
130+
131+
def get(self, name: str) -> Any:
    """Return the value of a required attribute.

    Args:
        name: Attribute name.

    Returns:
        Attribute value (never None).

    Raises:
        ValueError: If the attribute is None or is not present.
    """
    value = getattr(self, name, None)
    if value is not None:
        return value
    message = f"Attribute '{name}' must be set (not None) by this point."
    raise ValueError(message)
144+
129145
def to_dict(self) -> dict[str, Any]:
130146
"""Generate JSON representation of a launch specification.
131147
@@ -134,19 +150,19 @@ def to_dict(self) -> dict[str, Any]:
134150
"""
135151
output = {
136152
"launch": {
137-
"computeEnvId": self.compute_env_id,
138-
"configProfiles": self.dedup(self.profiles),
153+
"computeEnvId": self.get("compute_env_id"),
154+
"configProfiles": dedup(self.profiles),
139155
"configText": self.nextflow_config,
140156
"dateCreated": None,
141157
"entryName": None,
142158
"headJobCpus": None,
143159
"headJobMemoryMb": None,
144160
"id": None,
145-
"labelIds": self.label_ids,
161+
"labelIds": dedup(self.label_ids),
146162
"mainScript": None,
147163
"optimizationId": None,
148164
"paramsText": json.dumps(self.params),
149-
"pipeline": self.pipeline,
165+
"pipeline": self.get("pipeline"),
150166
"postRunScript": None,
151167
"preRunScript": self.pre_run_script,
152168
"pullLatest": False,
@@ -156,9 +172,9 @@ def to_dict(self) -> dict[str, Any]:
156172
"schemaName": None,
157173
"stubRun": False,
158174
"towerConfig": None,
159-
"userSecrets": self.dedup(self.user_secrets),
160-
"workDir": self.work_dir,
161-
"workspaceSecrets": self.dedup(self.workspace_secrets),
175+
"userSecrets": dedup(self.user_secrets),
176+
"workDir": self.get("work_dir"),
177+
"workspaceSecrets": dedup(self.workspace_secrets),
162178
}
163179
}
164180
return output

src/orca/services/nextflowtower/ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,11 @@ def launch_workflow(
127127
launch_info.fill_in("compute_env_id", compute_env_id)
128128
launch_info.fill_in("work_dir", compute_env.work_dir)
129129
launch_info.fill_in("pre_run_script", compute_env.pre_run_script)
130-
launch_info.fill_in("label_ids", label_ids)
130+
launch_info.add_in("label_ids", label_ids)
131131

132132
return self.client.launch_workflow(launch_info, self.workspace_id)
133133

134+
# TODO: Consider switching return value to a namedtuple
134135
def get_workflow_status(self, workflow_id: str) -> tuple[TaskStatus, bool]:
135136
"""Gets status of workflow run
136137
@@ -146,5 +147,4 @@ def get_workflow_status(self, workflow_id: str) -> tuple[TaskStatus, bool]:
146147
)
147148
task_status = cast(TaskStatus, response["workflow"]["status"])
148149
is_done = task_status in TaskStatus.terminal_states.value
149-
# TODO: Consider switching return value to a namedtuple
150150
return task_status, is_done

src/orca/services/nextflowtower/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
from collections.abc import Collection
12
from datetime import datetime, timezone
3+
from typing import TypeVar
4+
5+
T = TypeVar("T", int, str)
26

37

48
def parse_datetime(text: str) -> datetime:
@@ -13,3 +17,15 @@ def parse_datetime(text: str) -> datetime:
1317
parsed = datetime.strptime(text, "%Y-%m-%dT%H:%M:%SZ")
1418
parsed = parsed.replace(tzinfo=timezone.utc)
1519
return parsed
20+
21+
22+
def dedup(items: Collection[T]) -> list[T]:
    """Deduplicate items in a collection, keeping first-seen order.

    Args:
        items: Collection of hashable items (may be empty).

    Returns:
        New list with each distinct item appearing once, ordered by
        its first occurrence in ``items``.
    """
    # dict keys are unique and keep insertion order, whereas iterating
    # a set of strings can yield a different order on every run (hash
    # randomization), which made the output nondeterministic.
    return list(dict.fromkeys(items))

tests/services/nextflowtower/test_integration.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,10 @@ def test_that_a_valid_client_can_be_constructed_and_tested(client):
3939
def test_that_a_workflow_can_be_launched(ops):
4040
scratch_bucket = "s3://orca-service-test-project-tower-scratch/"
4141
launch_info = models.LaunchInfo(
42-
compute_env_id="5ykJFs33AE3d3AgThavz3b",
4342
pipeline="nf-core/rnaseq",
4443
revision="3.11.2",
4544
profiles=["test"],
4645
params={"outdir": f"{scratch_bucket}/2days/launch_test"},
47-
work_dir=f"{scratch_bucket}/work",
4846
)
4947
workflow_id = ops.launch_workflow(launch_info, "ondemand")
5048
assert workflow_id
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import pytest
2+
3+
from orca.services.nextflowtower.models import LaunchInfo
4+
5+
6+
def test_that_getting_an_launch_info_attribute_works():
    """A populated attribute is returned as-is by ``LaunchInfo.get``."""
    info = LaunchInfo(pipeline="foo")
    retrieved = info.get("pipeline")
    assert retrieved == "foo"
9+
10+
11+
def test_for_an_error_when_getting_an_launch_info_attribute_that_is_missing():
    """``LaunchInfo.get`` raises ValueError when the attribute is None."""
    info = LaunchInfo()
    with pytest.raises(ValueError):
        info.get("pipeline")
15+
16+
17+
def test_that_launch_info_attribute_can_be_filled_in():
    """``fill_in`` sets an attribute that starts out as None."""
    info = LaunchInfo()
    # Precondition: nothing was supplied for the pipeline.
    assert info.pipeline is None
    info.fill_in("pipeline", "foo")
    assert info.pipeline == "foo"
22+
23+
24+
def test_that_launch_info_list_attribute_can_be_added_in():
    """``add_in`` extends a list attribute with the new values."""
    launch_info = LaunchInfo(label_ids=[1, 2, 3])
    assert launch_info.label_ids == [1, 2, 3]
    launch_info.add_in("label_ids", [4, 5, 6])
    # Compare order-independently: the merged list goes through a
    # set-based dedup, so element order is an implementation detail.
    assert sorted(launch_info.label_ids) == [1, 2, 3, 4, 5, 6]
29+
30+
31+
def test_for_an_error_when_adding_in_with_nonlist_launch_info_attribute():
    """``add_in`` rejects attributes whose current value is not a list."""
    info = LaunchInfo(pipeline="foo")
    extra_values = [4, 5, 6]
    with pytest.raises(ValueError):
        info.add_in("pipeline", extra_values)

tests/services/nextflowtower/test_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,9 @@
66
def test_that_parse_datetime_works():
77
result = utils.parse_datetime("2023-04-26T00:49:49Z")
88
assert result == datetime(2023, 4, 26, 0, 49, 49, tzinfo=timezone.utc)
9+
10+
11+
def test_that_launch_info_dedup_works():
    """``dedup`` keeps exactly one copy of each distinct item."""
    secrets = ["foo", "bar", "baz", "foo"]
    dedupped = utils.dedup(secrets)
    assert len(dedupped) == 3
    # Length alone would pass even if the wrong elements survived;
    # also check the contents (order-independently).
    assert sorted(dedupped) == ["bar", "baz", "foo"]

0 commit comments

Comments
 (0)