Skip to content

Commit 2346973

Browse files
add access updates for tiled (#217)
* add access updates for tiled * Update utils/data_utils.py Co-authored-by: Wiebke Köpp <wkoepp@lbl.gov> * remove ggshield * reduce size of np fixture --------- Co-authored-by: Wiebke Köpp <wkoepp@lbl.gov>
1 parent 59de389 commit 2346973

File tree

6 files changed

+163
-9
lines changed

6 files changed

+163
-9
lines changed

.env.example

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,4 @@ MLFLOW_TRACKING_USERNAME=
5252
MLFLOW_TRACKING_PASSWORD=
5353
#algorithm registry in mlflow
5454
MLFLOW_TRACKING_URI_OUTSIDE=http://localhost:5000
55-
ALGORITHM_JSON_PATH="../assets/models.json"
55+
ALGORITHM_JSON_PATH="../assets/models.json"

.pre-commit-config.yaml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,6 @@ repos:
1212
- id: check-symlinks
1313
- id: check-yaml
1414
- id: debug-statements
15-
- repo: https://github.com/gitguardian/ggshield
16-
rev: v1.25.0
17-
hooks:
18-
- id: ggshield
19-
language_version: python3
20-
stages: [commit]
2115
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
2216
- repo: https://github.com/psf/black-pre-commit-mirror
2317
rev: 24.2.0

scripts/save_mlflow_algorithm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
# Add the project root directory to Python path to fix imports
99
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1010

11-
from mlex_utils.mlflow_utils.mlflow_algorithm_client import MlflowAlgorithmClient
11+
from mlex_utils.mlflow_utils.mlflow_algorithm_client import ( # noqa: E402
12+
MlflowAlgorithmClient,
13+
)
1214

1315
# Load environment variables from .env file
1416
load_dotenv(dotenv_path="../.env")

tests/unit/test_control_bar.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,23 @@
1+
import sys
2+
13
import numpy as np
24
import pytest
35

46

7+
class DummyMlflowAlgorithmClient:
    """Inert stand-in patched in place of ``MlflowAlgorithmClient`` by the tests.

    It accepts any constructor arguments, exposes an empty ``modelname_list``,
    does nothing when asked to load, and holds no entries to look up.
    """

    def __init__(self, *_args, **_kwargs):
        # Constructor arguments are accepted and ignored.
        self.modelname_list = list()

    def load_from_mlflow(self, algorithm_type="segmentation"):
        """Do nothing and return ``None``; ``algorithm_type`` is ignored."""
        return

    def __getitem__(self, key):
        """Always raise ``KeyError`` for ``key``: the stub contains nothing."""
        raise KeyError(key)
16+
17+
518
@pytest.fixture
619
def tiled_data_mock():
7-
return {"sample_project": np.zeros((500, 500, 500)), "reconstruction": 0}
20+
return {"sample_project": np.zeros((2, 500, 500)), "reconstruction": 0}
821

922

1023
@pytest.fixture
@@ -17,6 +30,12 @@ def test_reset_filters(mocker, nclicks, tiled_data_mock):
1730
"tiled.client.from_uri",
1831
return_value=tiled_data_mock,
1932
)
33+
mocker.patch(
34+
"mlex_utils.mlflow_utils.mlflow_algorithm_client.MlflowAlgorithmClient",
35+
DummyMlflowAlgorithmClient,
36+
)
37+
sys.modules.pop("callbacks.control_bar", None)
38+
sys.modules.pop("utils.data_utils", None)
2039
from callbacks.control_bar import reset_filters
2140

2241
assert reset_filters(nclicks) == 100

tests/unit/test_data_utils.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import importlib
2+
import sys
3+
from types import SimpleNamespace
4+
5+
6+
class DummyMlflowAlgorithmClient:
    """Inert stand-in patched in place of ``MlflowAlgorithmClient`` by the tests.

    It accepts any constructor arguments, exposes an empty ``modelname_list``,
    does nothing when asked to load, and holds no entries to look up.
    """

    def __init__(self, *_args, **_kwargs):
        # Constructor arguments are accepted and ignored.
        self.modelname_list = list()

    def load_from_mlflow(self, algorithm_type="segmentation"):
        """Do nothing and return ``None``; ``algorithm_type`` is ignored."""
        return

    def __getitem__(self, key):
        """Always raise ``KeyError`` for ``key``: the stub contains nothing."""
        raise KeyError(key)
15+
16+
17+
class FakeContainer:
    """In-memory fake of a container client, keyed by child name.

    Children are kept in a plain dict. Each child's ``uri`` is the parent's
    ``uri`` followed by ``/`` and the child's key.
    """

    def __init__(self, uri="tiled://root"):
        self.uri = uri
        self._children = {}
        self.metadata = None

    def _child_uri(self, key):
        # URI a child stored under ``key`` is given.
        return f"{self.uri}/{key}"

    def _register(self, key, node):
        # Store ``node`` under ``key`` (overwriting any previous entry) and hand it back.
        self._children[key] = node
        return node

    def keys(self):
        """Return the dict keys view of the stored children."""
        return self._children.keys()

    def create_container(self, key, metadata=None):
        """Create, store and return a nested ``FakeContainer`` carrying ``metadata``."""
        child = FakeContainer(uri=self._child_uri(key))
        child.metadata = metadata
        return self._register(key, child)

    def write_array(self, key, array):
        """Store and return a namespace exposing ``uri`` and the given ``array``."""
        array_node = SimpleNamespace(uri=self._child_uri(key), array=array)
        return self._register(key, array_node)

    def __getitem__(self, key):
        """Return the child stored under ``key``; ``KeyError`` if absent."""
        return self._children[key]
39+
40+
41+
def test_save_annotations_data(mocker, monkeypatch):
    """Saving one rectangle annotation creates a single mask entry under the project path.

    The handler is built without running ``__init__`` and is pointed at an
    in-memory ``FakeContainer``; every Tiled / MLflow lookup used on the way
    is patched out.
    """
    # utils.data_utils reads its Tiled settings from the environment when it is
    # imported, so these must be in place before the import further down.
    tiled_env = {
        "DATA_TILED_URI": "http://example.com/api/v1/metadata/data",
        "MASK_TILED_URI": "http://example.com/api/v1/metadata/masks",
        "SEG_TILED_URI": "http://example.com/api/v1/metadata/seg",
    }
    for env_name, env_value in tiled_env.items():
        monkeypatch.setenv(env_name, env_value)

    mocker.patch("tiled.client.from_uri", return_value={})
    mocker.patch(
        "mlex_utils.mlflow_utils.mlflow_algorithm_client.MlflowAlgorithmClient",
        DummyMlflowAlgorithmClient,
    )

    # Drop any cached copy so the module is imported afresh with the patches active.
    sys.modules.pop("utils.data_utils", None)
    data_utils = importlib.import_module("utils.data_utils")

    root_container = FakeContainer()
    handler_cls = data_utils.TiledMaskHandler
    mask_handler = handler_cls.__new__(handler_cls)
    mask_handler.mask_client = root_container

    mocker.patch.object(
        data_utils, "from_uri", return_value=SimpleNamespace(access_blob={})
    )
    mocker.patch.object(data_utils, "copy_tiled_access_info", return_value=None)
    mocker.patch.object(
        data_utils.tiled_datasets,
        "get_data_uri_by_trimmed_uri",
        return_value="http://example.com/data/project/sample",
    )
    mocker.patch.object(
        data_utils.tiled_datasets,
        "get_data_sequence_by_trimmed_uri",
        return_value=SimpleNamespace(access_blob={}),
    )

    rectangle = {
        "type": "rect",
        "x0": 0,
        "y0": 0,
        "x1": 1,
        "y1": 1,
    }
    all_annotations = [
        {
            "class_id": "class-1",
            "label": "Class 1",
            "color": "#ffffff",
            "annotations": {"0": [rectangle]},
        }
    ]

    uri, num_classes, message = mask_handler.save_annotations_data(
        global_store={"image_shapes": [(4, 4)]},
        all_annotations=all_annotations,
        trimmed_uri="project/sample",
    )

    assert uri is not None
    assert num_classes == 1
    assert message == "Annotations saved successfully."
    assert "project/sample" in uri

    # Containers are nested as <user>/<project>/<sample>/<annotations hash>.
    user_container = root_container[data_utils.USER_NAME]
    project_container = user_container["project"]["sample"]
    saved_keys = list(project_container.keys())
    assert len(saved_keys) == 1

    saved_container = project_container[saved_keys[0]]
    assert saved_container.metadata["project_name"] == "project/sample"
    assert "mask" in saved_container.keys()

utils/data_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import logging
23
import os
34
import traceback
45
from urllib.parse import urlparse, urlunparse
@@ -15,6 +16,12 @@
1516

1617
load_dotenv()
1718

19+
# Setup basic logging
20+
logging.basicConfig(
21+
level=logging.INFO,
22+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
23+
)
24+
1825
DATA_TILED_URI = os.getenv("DATA_TILED_URI")
1926
DATA_TILED_API_KEY = os.getenv("DATA_TILED_API_KEY")
2027
MASK_TILED_URI = os.getenv("MASK_TILED_URI")
@@ -354,6 +361,10 @@ def save_annotations_data(self, global_store, all_annotations, trimmed_uri):
354361
key=annotations_hash, metadata=metadata
355362
)
356363
mask = last_container.write_array(key="mask", array=mask)
364+
image_client = tiled_datasets.get_data_sequence_by_trimmed_uri(trimmed_uri)
365+
# match the new container and new array access tags to the original data_client
366+
copy_tiled_access_info(image_client, last_container)
367+
copy_tiled_access_info(image_client, mask)
357368
else:
358369
last_container = last_container[annotations_hash]
359370
return (
@@ -439,3 +450,20 @@ def assemble_io_parameters_from_uris(data_uri, mask_uri):
439450
"seg_tiled_uri": SEG_TILED_URI,
440451
}
441452
return io_parameters
453+
454+
455+
def copy_tiled_access_info(source_client, target_client):
    """
    Copy the access tags of one Tiled client onto another.

    Reads ``source_client.access_blob``. When it is non-empty and has a
    ``"tags"`` entry that is not ``None``, those tags are written to the target
    with ``target_client.replace_metadata(access_tags=...)``. Otherwise the
    target is left untouched.

    Input:
        source_client: Tiled client object whose ``access_blob`` attribute
            (a mapping, or an empty/None value) is read.
        target_client: Tiled client object, The target Tiled client to copy access info to.
    Output:
        None
    """
    access_blob = source_client.access_blob
    # Look the tags up once; an empty or missing blob means there is nothing to copy.
    tags = access_blob.get("tags") if access_blob else None
    if tags is None:
        logging.debug(
            "No Tiled access tags found on %s; target left unchanged.",
            source_client,
        )
        return

    target_client.replace_metadata(access_tags=tags)
    # Report success only after the tags were actually written to the target.
    logging.info(
        "Tiled access information copied successfully %s %s.",
        source_client,
        access_blob,
    )

0 commit comments

Comments
 (0)