Skip to content
This repository was archived by the owner on Nov 1, 2023. It is now read-only.

Commit 1d74379

Browse files
authored
use the primitive types in more places (#514)
1 parent 51f4eea commit 1d74379

File tree

17 files changed

+74
-61
lines changed

17 files changed

+74
-61
lines changed

src/api-service/__app__/onefuzzlib/azure/creds.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from memoization import cached
1717
from msrestazure.azure_active_directory import MSIAuthentication
1818
from msrestazure.tools import parse_resource_id
19-
from onefuzztypes.primitives import Container
19+
from onefuzztypes.primitives import Container, Region
2020

2121
from .monkeypatch import allow_more_workers, reduce_logging
2222

@@ -41,12 +41,12 @@ def get_base_resource_group() -> Any: # should be str
4141

4242

4343
@cached
44-
def get_base_region() -> Any: # should be str
44+
def get_base_region() -> Region:
4545
client = ResourceManagementClient(
4646
credential=get_identity(), subscription_id=get_subscription()
4747
)
4848
group = client.resource_groups.get(get_base_resource_group())
49-
return group.location
49+
return Region(group.location)
5050

5151

5252
@cached
@@ -89,11 +89,11 @@ def get_instance_id() -> UUID:
8989

9090

9191
@cached(ttl=DAY_IN_SECONDS)
92-
def get_regions() -> List[str]:
92+
def get_regions() -> List[Region]:
9393
subscription = get_subscription()
9494
client = SubscriptionClient(credential=get_identity())
9595
locations = client.subscriptions.list_locations(subscription)
96-
return sorted([x.name for x in locations])
96+
return sorted([Region(x.name) for x in locations])
9797

9898

9999
@cached

src/api-service/__app__/onefuzzlib/tasks/config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def check_container(
4949
compare: Compare,
5050
expected: int,
5151
container_type: ContainerType,
52-
containers: Dict[ContainerType, List[str]],
52+
containers: Dict[ContainerType, List[Container]],
5353
) -> None:
5454
actual = len(containers.get(container_type, []))
5555
if not check_val(compare, expected, actual):
@@ -62,7 +62,7 @@ def check_container(
6262
def check_containers(definition: TaskDefinition, config: TaskConfig) -> None:
6363
checked = set()
6464

65-
containers: Dict[ContainerType, List[str]] = {}
65+
containers: Dict[ContainerType, List[Container]] = {}
6666
for container in config.containers:
6767
if container.name not in checked:
6868
if not container_exists(container.name, StorageType.corpus):

src/api-service/__app__/onefuzzlib/tasks/main.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from onefuzztypes.models import Error
1919
from onefuzztypes.models import Task as BASE_TASK
2020
from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
21+
from onefuzztypes.primitives import PoolName
2122

2223
from ..azure.image import get_os
2324
from ..azure.queue import create_queue, delete_queue
@@ -165,7 +166,7 @@ def get_by_task_id(cls, task_id: UUID) -> Union[Error, "Task"]:
165166
return task
166167

167168
@classmethod
168-
def get_tasks_by_pool_name(cls, pool_name: str) -> List["Task"]:
169+
def get_tasks_by_pool_name(cls, pool_name: PoolName) -> List["Task"]:
169170
tasks = cls.search_states(states=TaskState.available())
170171
if not tasks:
171172
return []

src/api-service/__app__/onefuzzlib/workers/nodes.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def search_states(
7272
*,
7373
scaleset_id: Optional[UUID] = None,
7474
states: Optional[List[NodeState]] = None,
75-
pool_name: Optional[str] = None,
75+
pool_name: Optional[PoolName] = None,
7676
) -> List["Node"]:
7777
query: QueryFilter = {}
7878
if scaleset_id:
@@ -89,7 +89,7 @@ def search_outdated(
8989
*,
9090
scaleset_id: Optional[UUID] = None,
9191
states: Optional[List[NodeState]] = None,
92-
pool_name: Optional[str] = None,
92+
pool_name: Optional[PoolName] = None,
9393
exclude_update_scheduled: bool = False,
9494
num_results: Optional[int] = None,
9595
) -> List["Node"]:

src/cli/examples/oss-fuzz-target.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import List, Optional
1313

1414
from onefuzztypes.models import NotificationConfig
15+
from onefuzztypes.primitives import PoolName
1516

1617
from onefuzz.api import Command, Onefuzz
1718
from onefuzz.cli import execute_api
@@ -42,7 +43,7 @@ def fuzz(
4243
self,
4344
project: str,
4445
build: str,
45-
pool: str,
46+
pool: PoolName,
4647
sanitizers: Optional[List[str]] = None,
4748
notification_config: Optional[NotificationConfig] = None,
4849
) -> None:

src/cli/onefuzz/api.py

+31-22
Original file line numberDiff line numberDiff line change
@@ -170,31 +170,34 @@ class Files(Endpoint):
170170
endpoint = "files"
171171

172172
@cached(ttl=ONE_HOUR_IN_SECONDS)
173-
def _get_client(self, container: str) -> ContainerWrapper:
173+
def _get_client(self, container: primitives.Container) -> ContainerWrapper:
174174
sas = self.onefuzz.containers.get(container).sas_url
175175
return ContainerWrapper(sas)
176176

177-
def list(self, container: str) -> models.Files:
177+
def list(self, container: primitives.Container) -> models.Files:
178178
""" Get a list of files in a container """
179179
self.logger.debug("listing files in container: %s", container)
180180
client = self._get_client(container)
181181
return models.Files(files=client.list_blobs())
182182

183-
def delete(self, container: str, filename: str) -> None:
183+
def delete(self, container: primitives.Container, filename: str) -> None:
184184
""" delete a file from a container """
185185
self.logger.debug("deleting in container: %s:%s", container, filename)
186186
client = self._get_client(container)
187187
client.delete_blob(filename)
188188

189-
def get(self, container: str, filename: str) -> bytes:
189+
def get(self, container: primitives.Container, filename: str) -> bytes:
190190
""" get a file from a container """
191191
self.logger.debug("getting file from container: %s:%s", container, filename)
192192
client = self._get_client(container)
193193
downloaded = client.download_blob(filename)
194194
return downloaded
195195

196196
def upload_file(
197-
self, container: str, file_path: str, blob_name: Optional[str] = None
197+
self,
198+
container: primitives.Container,
199+
file_path: str,
200+
blob_name: Optional[str] = None,
198201
) -> None:
199202
""" uploads a file to a container """
200203
if not blob_name:
@@ -212,7 +215,7 @@ def upload_file(
212215
client = self._get_client(container)
213216
client.upload_file(file_path, blob_name)
214217

215-
def upload_dir(self, container: str, dir_path: str) -> None:
218+
def upload_dir(self, container: primitives.Container, dir_path: str) -> None:
216219
""" uploads a directory to a container """
217220

218221
self.logger.debug("uploading directory to container %s:%s", container, dir_path)
@@ -476,7 +479,9 @@ def get(self, vm_id: UUID_EXPANSION) -> models.Repro:
476479
"GET", models.Repro, data=requests.ReproGet(vm_id=vm_id_expanded)
477480
)
478481

479-
def create(self, container: str, path: str, duration: int = 24) -> models.Repro:
482+
def create(
483+
self, container: primitives.Container, path: str, duration: int = 24
484+
) -> models.Repro:
480485
""" Create a Reproduction VM from a Crash Report """
481486
self.logger.info(
482487
"creating repro vm: %s %s (%d hours)", container, path, duration
@@ -651,7 +656,7 @@ def func() -> Tuple[bool, str, models.Repro]:
651656

652657
def create_and_connect(
653658
self,
654-
container: str,
659+
container: primitives.Container,
655660
path: str,
656661
duration: int = 24,
657662
delete_after_use: bool = False,
@@ -670,14 +675,16 @@ class Notifications(Endpoint):
670675
endpoint = "notifications"
671676

672677
def create(
673-
self, container: str, config: models.NotificationConfig
678+
self, container: primitives.Container, config: models.NotificationConfig
674679
) -> models.Notification:
675680
""" Create a notification based on a config file """
676681

677682
config = requests.NotificationCreate(container=container, config=config.config)
678683
return self._req_model("POST", models.Notification, data=config)
679684

680-
def create_teams(self, container: str, url: str) -> models.Notification:
685+
def create_teams(
686+
self, container: primitives.Container, url: str
687+
) -> models.Notification:
681688
""" Create a Teams notification integration """
682689

683690
self.logger.debug("create teams notification integration: %s", container)
@@ -687,7 +694,7 @@ def create_teams(self, container: str, url: str) -> models.Notification:
687694

688695
def create_ado(
689696
self,
690-
container: str,
697+
container: primitives.Container,
691698
project: str,
692699
base_url: str,
693700
auth_token: str,
@@ -804,7 +811,7 @@ def create(
804811
ensemble_sync_delay: Optional[int] = None,
805812
generator_exe: Optional[str] = None,
806813
generator_options: Optional[List[str]] = None,
807-
pool_name: str,
814+
pool_name: primitives.PoolName,
808815
prereq_tasks: Optional[List[UUID]] = None,
809816
reboot_after_setup: bool = False,
810817
rename_output: bool = False,
@@ -1049,7 +1056,7 @@ def create(
10491056
),
10501057
)
10511058

1052-
def get_config(self, pool_name: str) -> models.AgentConfig:
1059+
def get_config(self, pool_name: primitives.PoolName) -> models.AgentConfig:
10531060
""" Get the agent configuration for the pool """
10541061

10551062
pool = self.get(pool_name)
@@ -1168,17 +1175,19 @@ def list(
11681175
*,
11691176
state: Optional[List[enums.NodeState]] = None,
11701177
scaleset_id: Optional[UUID_EXPANSION] = None,
1171-
pool_name: Optional[str] = None,
1178+
pool_name: Optional[primitives.PoolName] = None,
11721179
) -> List[models.Node]:
11731180
self.logger.debug("list nodes")
11741181
scaleset_id_expanded: Optional[UUID] = None
11751182

11761183
if pool_name is not None:
1177-
pool_name = self._disambiguate(
1178-
"name",
1179-
pool_name,
1180-
lambda x: False,
1181-
lambda: [x.name for x in self.onefuzz.pools.list()],
1184+
pool_name = primitives.PoolName(
1185+
self._disambiguate(
1186+
"name",
1187+
str(pool_name),
1188+
lambda x: False,
1189+
lambda: [x.name for x in self.onefuzz.pools.list()],
1190+
)
11821191
)
11831192

11841193
if scaleset_id is not None:
@@ -1242,12 +1251,12 @@ def _expand_scaleset_machine(
12421251

12431252
def create(
12441253
self,
1245-
pool_name: str,
1254+
pool_name: primitives.PoolName,
12461255
size: int,
12471256
*,
12481257
image: Optional[str] = None,
12491258
vm_sku: Optional[str] = "Standard_D2s_v3",
1250-
region: Optional[str] = None,
1259+
region: Optional[primitives.Region] = None,
12511260
spot_instances: bool = False,
12521261
tags: Optional[Dict[str, str]] = None,
12531262
) -> models.Scaleset:
@@ -1375,7 +1384,7 @@ def delete(
13751384
),
13761385
)
13771386

1378-
def reset(self, region: str) -> responses.BoolResult:
1387+
def reset(self, region: primitives.Region) -> responses.BoolResult:
13791388
""" Reset the proxy for an existing region """
13801389

13811390
return self._req_model(

src/cli/onefuzz/cli.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import jmespath
3333
from docstring_parser import parse as parse_docstring
3434
from msrest.serialization import Model
35-
from onefuzztypes.primitives import Container, Directory, File
35+
from onefuzztypes.primitives import Container, Directory, File, PoolName, Region
3636
from pydantic import BaseModel, ValidationError
3737

3838
LOGGER = logging.getLogger("cli")
@@ -158,6 +158,8 @@ def __init__(self, api_types: List[Any]):
158158
int: {"type": int},
159159
UUID: {"type": UUID},
160160
Container: {"type": str},
161+
Region: {"type": str},
162+
PoolName: {"type": str},
161163
File: {"type": arg_file},
162164
Directory: {"type": arg_dir},
163165
}

src/cli/onefuzz/debug.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from azure.common.client_factory import get_azure_cli_credentials
2121
from onefuzztypes.enums import ContainerType, TaskType
2222
from onefuzztypes.models import BlobRef, NodeAssignment, Report, Task
23-
from onefuzztypes.primitives import Directory
23+
from onefuzztypes.primitives import Container, Directory
2424

2525
from onefuzz.api import UUID_EXPANSION, Command, Onefuzz
2626

@@ -583,13 +583,13 @@ class DebugNotification(Command):
583583

584584
def _get_container(
585585
self, task: Task, container_type: ContainerType
586-
) -> Optional[str]:
586+
) -> Optional[Container]:
587587
for container in task.config.containers:
588588
if container.type == container_type:
589589
return container.name
590590
return None
591591

592-
def _get_storage_account(self, container_name: str) -> str:
592+
def _get_storage_account(self, container_name: Container) -> str:
593593
sas_url = self.onefuzz.containers.get(container_name).sas_url
594594
_, netloc, _, _, _, _ = urlparse(sas_url)
595595
return netloc.split(".")[0]

src/cli/onefuzz/status/cache.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
TaskContainers,
3737
UserInfo,
3838
)
39-
from onefuzztypes.primitives import Container
39+
from onefuzztypes.primitives import Container, PoolName
4040
from pydantic import BaseModel
4141

4242
MESSAGE = Tuple[datetime, EventType, str]
@@ -49,7 +49,7 @@
4949
# status-top only representation of a Node
5050
class MiniNode(BaseModel):
5151
machine_id: UUID
52-
pool_name: str
52+
pool_name: PoolName
5353
state: NodeState
5454

5555

src/cli/onefuzz/templates/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(
7171
self.project = project
7272
self.name = name
7373
self.build = build
74-
self.to_monitor: Dict[str, int] = {}
74+
self.to_monitor: Dict[Container, int] = {}
7575

7676
if platform is None:
7777
self.platform = JobHelper.get_platform(target_exe)

src/cli/onefuzz/templates/afl.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from onefuzztypes.enums import OS, ContainerType, StatsFormat, TaskDebugFlag, TaskType
99
from onefuzztypes.models import Job, NotificationConfig
10-
from onefuzztypes.primitives import Container, Directory, File
10+
from onefuzztypes.primitives import Container, Directory, File, PoolName
1111

1212
from onefuzz.api import Command
1313

@@ -23,7 +23,7 @@ def basic(
2323
name: str,
2424
build: str,
2525
*,
26-
pool_name: str,
26+
pool_name: PoolName,
2727
target_exe: File = File("fuzz.exe"),
2828
setup_dir: Optional[Directory] = None,
2929
vm_count: int = 2,

src/cli/onefuzz/templates/libfuzzer.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from onefuzztypes.enums import ContainerType, TaskDebugFlag, TaskType
99
from onefuzztypes.models import Job, NotificationConfig
10-
from onefuzztypes.primitives import Container, Directory, File
10+
from onefuzztypes.primitives import Container, Directory, File, PoolName
1111

1212
from onefuzz.api import Command
1313

@@ -35,7 +35,7 @@ def _create_tasks(
3535
*,
3636
job: Job,
3737
containers: Dict[ContainerType, Container],
38-
pool_name: str,
38+
pool_name: PoolName,
3939
target_exe: str,
4040
vm_count: int = 2,
4141
reboot_after_setup: bool = False,
@@ -145,7 +145,7 @@ def basic(
145145
project: str,
146146
name: str,
147147
build: str,
148-
pool_name: str,
148+
pool_name: PoolName,
149149
*,
150150
target_exe: File = File("fuzz.exe"),
151151
setup_dir: Optional[Directory] = None,
@@ -261,7 +261,7 @@ def merge(
261261
project: str,
262262
name: str,
263263
build: str,
264-
pool_name: str,
264+
pool_name: PoolName,
265265
*,
266266
target_exe: File = File("fuzz.exe"),
267267
setup_dir: Optional[Directory] = None,

0 commit comments

Comments
 (0)